241 Commits

Author SHA1 Message Date
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
jax authors
0c4c020716 Include compile time along with executable in cache entry.
In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit.

Testing: updated tests.
PiperOrigin-RevId: 549439628
2023-07-19 15:17:45 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
c287b2a1db Add missing %s parameter to error message in _cache_write.
PiperOrigin-RevId: 539188601
2023-06-09 15:12:38 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Yash Katariya
b196ad2e8c Remove the f-string evaluation during logging the elapsed time by passing in fun_name to log_elapsed_time
PiperOrigin-RevId: 532132574
2023-05-15 09:15:58 -07:00
Yash Katariya
8e1ad734bc Log the time it takes to lower from jaxpr to stableHLO
PiperOrigin-RevId: 532115098
2023-05-15 08:08:13 -07:00
Yash Katariya
1bef7c9787 Fix McJAX resharding when the input has a fully replicated sharding
PiperOrigin-RevId: 531263333
2023-05-11 11:42:36 -07:00
Yash Katariya
2694bf6207 Use set equality operators instead of intersection because I didn't know set had equality operators.
PiperOrigin-RevId: 530688786
2023-05-09 12:55:47 -07:00
Yash Katariya
18d19caa1c Add McJAX resharding to device_put. Allow resharding if inputs and target sharding have the same set of devices but different order.
We can make this general enough in JAX slowly and carefully and would likely require a refactor of how device_assignment is chosen.

Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
2023-05-09 09:58:12 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Yash Katariya
b698390171 Handle multihost pmap in pmap shard_map merge. This involves lifting the host local inputs to global inputs and vice-versa on the outputs.
To handle Tracers, ShapedArray, concrete Arrays, etc `global_array_to_host_local_array` and `host_local_array_to_global_array` are now primitives.

PiperOrigin-RevId: 528925663
2023-05-02 16:53:22 -07:00
Yash Katariya
34d5a6259f Default jax_spmd_mode to allow_jit which will allow explicit jax.jit to not raise the multihost error (since jit and pjit have been merged).
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.

PiperOrigin-RevId: 527398394
2023-04-26 15:56:46 -07:00
Peter Hawkins
1d63d9b833 Include the device_kind in the compilation cache key.
PiperOrigin-RevId: 525726898
2023-04-20 06:16:45 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
017548c40b Move implementation of compilation cache out of jax/experimental and into jax/_src.
Use a Protocol instead of an abstract base class for the CacheInterface since it allows us to use one fewer file.

No functional change intended.

PiperOrigin-RevId: 524855263
2023-04-17 08:35:53 -07:00
Yash Katariya
febd339742 [Micro-optimization] Only log the avals and shardings if logging is enabled for that level.
PiperOrigin-RevId: 524845969
2023-04-17 07:53:37 -07:00
Yash Katariya
b06d627c05 Remove _allow_propagation_to_outputs from compile in MeshComputation since after jax.Array it is not required and can just default to being set to True if a sharding is unspecified.
PiperOrigin-RevId: 523851611
2023-04-12 17:38:18 -07:00
Yash Katariya
5d1abe1ba9 Make apply_primitive preserve shardings on outputs.
PiperOrigin-RevId: 523186148
2023-04-10 12:41:02 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
jax authors
beeac2cac2 Merge pull request #15497 from mattjj:issue15385
PiperOrigin-RevId: 523112617
2023-04-10 07:31:50 -07:00
Yash Katariya
5d2f453094 Preserve shardings on the output of pjit that were provided on the arguments.
Following are the changes:

* Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top level cache should fill up eventually if users are passing different shardings into the pjit function.
* Split lower_sharding_computation into 3 caches:
  * _trace_to_jaxpr_and_dce cache -- This will return a closed jaxpr which is DCE'd
  * _cached_lowering_to_hlo cache -- This will cache the generation of MHLO. This cache is dependent on the semantic equality of shardings i.e. if 2 shardings lower to the same OpSharding, then there will be a cache hit
  * _cached_compilation cache -- This caches the compilation so that we don't recompile if the shardings are semantically equal.

The way this works is the out_handlers are created again if we pass in different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user.

For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding (right now, more support will be added in following CLs) from the GSPMDSharding since it will be available on the input.

PiperOrigin-RevId: 522991145
2023-04-09 15:42:11 -07:00
Matthew Johnson
e04409f088 [shard-map] fix jaxpr_shardings logic for shmap with no outputs
fixes #15385
2023-04-08 23:03:49 -07:00
Peter Hawkins
b4402185db Move PartitionSpec into its own file (jax/_src/partition_spec.py).
No functional changes intended.

A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.

PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Yash Katariya
c978df5dbb Delete unused functions from dispatch.py and pjit.py
PiperOrigin-RevId: 520730163
2023-03-30 13:38:44 -07:00
Peter Hawkins
23451dc764 Merge pull request #15303 from jakevdp:lax-asarray
PiperOrigin-RevId: 520717999
2023-03-30 20:11:11 +00:00
Yash Katariya
830cd9fd98 Delete _single_device_array_from_buf since everything from JAX is an Array
PiperOrigin-RevId: 520418231
2023-03-29 12:59:12 -07:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
Yash Katariya
b5c9c0f47e Raise a better error message when there is a device assignment mismatch via the apply_primitive route.
PiperOrigin-RevId: 518282464
2023-03-21 08:40:42 -07:00
Yash Katariya
d02f28199b Clean up pjit after jax.Array
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed

PiperOrigin-RevId: 517469390
2023-03-17 11:53:00 -07:00
Yash Katariya
6d0189e810 Remove dispatch.result_handlers since they are not used.
PiperOrigin-RevId: 517456171
2023-03-17 11:02:22 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Yash Katariya
f9468d3879 Remove the helper jit functions from api.py
PiperOrigin-RevId: 517152277
2023-03-16 10:08:00 -07:00
Yash Katariya
6a0c8069dc Remove the check for if not isinstance(old_token, array.ArrayImpl) since py_executable always return jax.Arrays
PiperOrigin-RevId: 516974728
2023-03-15 17:30:21 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Matthew Johnson
54b889ca7f [dynamic-shapes] don't require buf objects have dtype attribute
Fixes iree-org/iree-jax#57

An alternative fix would've been just to add the dtype attribute to IreeBuffer.
But it seems better not to make demands on the underlying runtime objects when
we don't need to.

I had to run the test with:

`JAX_PLATFORM_NAME=iree JAX_ARRAY=0 JAX_JIT_PJIT_API_MERGE=0 python tests/dynamic_api_test.py DynamicShapeTest.test_iree_buffer_doesnt_need_dtype_attribute`
2023-03-15 12:53:43 -07:00
Yash Katariya
b97fb56e95 If the bufs are on the same devices passed to batched_device_put then create an Array directly rather than going via xc.batched_device_put. Fixing the transfer guard problem should help in removing this workaround too.
PiperOrigin-RevId: 516561791
2023-03-14 10:19:37 -07:00
Yash Katariya
a71d09b950 Don't use the dispatch.device_put path since it is deprecated and will be removed soon.
PiperOrigin-RevId: 516370577
2023-03-13 17:39:28 -07:00
Parker Schuh
5aa74acbcd Rollforward with fixes: Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 516317920
2023-03-13 14:11:10 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Yash Katariya
a70c95641c Fix copy_array_to_devices_with_sharding by making it take a committed argument so that the Array created has the right semantics.
PiperOrigin-RevId: 515755664
2023-03-10 15:38:28 -08:00
Peter Hawkins
cca3961cde [JAX] Split _src/xla_bridge.py into a separate Bazel target.
Include _src/distributed.py and _src/clusters/*.py in the same target because they are in a strongly-connected component.

[XLA:Python] Set type of ArrayImpl to Any, since the JAX change now allows pytype to see that some values are ArrayImpls but ArrayImpls are not instances of jax.Array to Pytype.

Fix type of buffer_from_pyval.

PiperOrigin-RevId: 515687258
2023-03-10 11:12:02 -08:00
Yash Katariya
00b90e9073 [Rollback] Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 515659122
2023-03-10 09:36:18 -08:00
Yash Katariya
626221aaa2 Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 515493002
2023-03-09 18:01:05 -08:00
Yash Katariya
50408fd694 Use shard_args and global_result_handlers since the aval_to_result_handler and dispatch.device_put will be removed soon.
PiperOrigin-RevId: 515471662
2023-03-09 16:19:10 -08:00