158 Commits

Author SHA1 Message Date
Peter Hawkins
5527966b27 [JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.

PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00
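As a quick illustration of the two supported conversions named above (a minimal sketch; the array `x` is illustrative):
```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)       # a JAX array on the default device
a = np.asarray(x)         # goes through __array__, copies to host
b = jax.device_get(x)     # also returns a host NumPy array
assert isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
```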
Peter Hawkins
78cb9f8492 Avoid more direct references to jax._src without imports.
Change in preparation for not exporting jax._src by default.

PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8 Enable fast path in the Array constructor. This means that the rearranging of _arrays according to the device_assignment won't happen when the fast path is enabled, because we assume that jax transformations will return the right arrangement.
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
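As an illustration of the pattern described above (the choice of `test_util` here is ours, not from the commit):
```python
# Before: relies on `import jax` implicitly exposing jax._src
import jax
jtu = jax._src.test_util   # breaks once jax._src is no longer exported

# After: import the submodule from jax._src explicitly
from jax._src import test_util as jtu
```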
Yash Katariya
ff7cd9a136 Add the autosharding error test back but disable it for now.
PiperOrigin-RevId: 468699560
2022-08-19 07:44:43 -07:00
Yash Katariya
c55e4dc207 Add support for doing grad of pjit (similar to what jit supports). Resolve in_shardings in _pjit_call_impl (those that were UNSPECIFIED) before lowering to XLA. Then check that the device assignment is the same across shardings in lower_sharding_computation.
PiperOrigin-RevId: 468251065
2022-08-17 11:28:08 -07:00
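A hedged sketch of what this enables, using the pjit API of this era (the 1-D mesh and the loss function are our own illustrative choices):
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit

mesh = Mesh(np.array(jax.devices()), ('x',))  # illustrative 1-D mesh
f = pjit(lambda x: jnp.sum(x * x),
         in_axis_resources=P('x'),
         out_axis_resources=None)

with mesh:
    x = jnp.arange(8.0)
    g = jax.grad(f)(x)  # grad now traces through pjit, as it does for jit
```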
Yash Katariya
022d92b791 Add support for giving sharding instances as input to with_sharding_constraint.
PiperOrigin-RevId: 467924064
2022-08-16 07:51:53 -07:00
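A sketch of the new calling convention, assuming the MeshPspecSharding class from jax.experimental.sharding of this era (the mesh setup and function body are illustrative):
```python
import numpy as np
import jax
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental.sharding import MeshPspecSharding

mesh = Mesh(np.array(jax.devices()), ('x',))

def body(x):
    y = x * 2
    # a Sharding instance is now accepted here, not just a PartitionSpec
    return with_sharding_constraint(y, MeshPspecSharding(mesh, P('x')))

f = pjit(body, in_axis_resources=None, out_axis_resources=None)
```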
Yash Katariya
3d2026a1d0 Create a SameDeviceAssignmentTuple type to cache the op shardings and device assignment. The device_assignment is cached only once, because pjit checks right at the start whether all device assignments are equal.
PiperOrigin-RevId: 467051286
2022-08-11 14:35:59 -07:00
Parker Schuh
310c7a2934 Make vmap axis actually unconstrained by default. 2022-08-11 11:06:12 -07:00
Parker Schuh
8fb957350c Add spmd_axis_name to vmap to allow constraining mapped PartitionSpecs. 2022-08-08 19:41:42 -07:00
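A minimal sketch of the new parameter (the mesh axis name 'data' and the function are stand-ins):
```python
import jax

def per_example_loss(x):   # stand-in for a real per-example computation
    return (x ** 2).sum()

# spmd_axis_name constrains the mapped (batch) axis to the mesh axis
# named 'data' instead of leaving it unconstrained
batched_loss = jax.vmap(per_example_loss, spmd_axis_name='data')
```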
Yash Katariya
480efcf0ee Add a test simulating a training run's _pjit_lower cache hits in pjit.
PiperOrigin-RevId: 466055367
2022-08-08 08:56:06 -07:00
Yash Katariya
c02359b924 Add early support in pjit for single device shardings. Also lift the restriction of needing the mesh context manager when config.jax_array is enabled.
PiperOrigin-RevId: 465712981
2022-08-05 22:25:25 -07:00
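A hedged sketch of the lifted restriction (assuming the jax_array flag of this era):
```python
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

jax.config.update('jax_array', True)

# no surrounding mesh context manager needed for a single-device sharding
f = pjit(lambda x: x + 1)
y = f(jnp.arange(4))
```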
Yash Katariya
a427dc9ecb Treat all shardings on a single device as equivalent.
PiperOrigin-RevId: 465685287
2022-08-05 18:02:03 -07:00
Yash Katariya
4b6d4a4ef7 Don't depend on mesh for UNSPECIFIED. Use OpShardingSharding for that, since it's now available and pjit accepts it.
PiperOrigin-RevId: 465641117
2022-08-05 13:55:21 -07:00
Yash Katariya
007d651ac8 Canonicalize all shardings to OpShardingSharding throughout pjit. Where a pspec is needed, the parsed_flatten_op_sharding function is used to retrieve it. The major places are global_to_local and local_to_global. The rest of the changes are just threading OpShardingSharding through.
I have added comments to places to explain things.

The dependence on MeshPspecSharding in partial eval has been removed. It now depends on OpShardingSharding.

TODO: Fix the round trip through MeshPspecSharding in vmap batching handlers.
PiperOrigin-RevId: 465621165
2022-08-05 12:18:10 -07:00
Yash Katariya
9945510de3 Fix the __eq__ check of OpShardingSharding and add a test to check the cache for device_indices_map. Also, check for OpSharding replication via the is_op_sharding_replicated function.
PiperOrigin-RevId: 465586777
2022-08-05 09:59:56 -07:00
Yash Katariya
47623264db Export HloSharding via pybind, which is a C++ wrapper around the OpSharding proto.
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
Yash Katariya
97f2f4efa4 Cache the replacement of FROM_GDA with the actual shardings present on the GDA, and check for sharding equality via OpSharding.
PiperOrigin-RevId: 463388009
2022-07-26 11:33:57 -07:00
Yash Katariya
b42c84f26f Add an OpSharding equality function until the HloSharding class is exported via pybind. The equality behavior is the same as HloSharding's.
PiperOrigin-RevId: 463162918
2022-07-25 13:24:33 -07:00
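A hedged sketch of what such an equality function has to compare, mirroring HloSharding semantics (the helper name is ours; the fields are from the OpSharding proto):
```python
from jax._src.lib import xla_client as xc

def op_shardings_equal(a: xc.OpSharding, b: xc.OpSharding) -> bool:
    # illustrative: compare the proto fields that determine the sharding,
    # the way HloSharding equality does
    return (a.type == b.type and
            a.tile_assignment_dimensions == b.tile_assignment_dimensions and
            a.tile_assignment_devices == b.tile_assignment_devices and
            a.replicate_on_last_tile_dim == b.replicate_on_last_tile_dim)
```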
Yash Katariya
ea1593a9b2 Make the _check_shapes_against_resources check general for all XLACompatibleShardings by looking at the opsharding proto of the shardings.
PiperOrigin-RevId: 463161459
2022-07-25 13:18:18 -07:00
Yash Katariya
d8cbb29d14 OpSharding doesn't have __eq__ defined on it. Don't check sharding equality using opsharding until it does support that.
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Yash Katariya
90687cc1ff Make lower_mesh_computation accept sharding instances. The new path is tested, as everything in pjit goes through the new lower_sharding_computation except for AUTO and UNSPECIFIED (see below for these 2).
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths: the `spmd lowering path`, the `non-spmd lowering path`, and the `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
  * `lower_sharding_computation` works in SPMD mode since it's only used in pjit. The majority of the logic is the same. The only difference is that `mesh` does not exist in this function.

* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.

* `AUTO` and `UNSPECIFIED` cannot be used without a mesh right now, but I have a CL to fix this.

* The rest of the changes are to make all other functions play nicely with sharding instances.

PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00
Yash Katariya
0bc8f8abeb * Check if the device assignment is the same across input and output shardings.
* Allow mixed inputs only if the sharding matches what is specified in in_axis_resources.

PiperOrigin-RevId: 460326054
2022-07-11 16:27:11 -07:00
Yash Katariya
09ba51f323 Move _get_array_mapping from gda.py to pxla.py
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
Yash Katariya
bb2c5f111a Resolve TODOs and add some more checks for the jax.Array path.
PiperOrigin-RevId: 459808511
2022-07-08 12:19:19 -07:00
Yash Katariya
229ddecc45 * Remove AUTO from MeshPspecSharding and treat it like the _UNSPECIFIED singleton value.
* Support partial mentions of AUTO, which is currently supported by GDA and used in pax. Added tests for all of this.
  * As a consequence of this, I lifted the restriction on not providing `in_axis_resources` to pjit under `config.jax_array`.

* Made all auto sharding tests parameterized to test both gda and array.

PiperOrigin-RevId: 459776152
2022-07-08 09:45:23 -07:00
jax authors
55dcbec5b5 Merge pull request #11407 from hawkinsp:minver
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
jax authors
7ffedb5815 Merge pull request #11400 from jakevdp:deprecate-treeutil
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Yash Katariya
7da733f94b Change the internals of with_sharding_constraint to use the sharding instances.
PiperOrigin-RevId: 459600050
2022-07-07 14:22:10 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
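For example, call sites move from the top-level aliases to the namespaced functions:
```python
import jax

# preferred: the jax.tree_util namespace
jax.tree_util.tree_map(lambda v: v * 2, {'a': 1, 'b': 2})

# rather than the top-level alias jax.tree_map(...)
```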
Yash Katariya
2314951669 Convert everything in pjit to the Sharding interface. The following are the things that have changed in this CL:
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.

* `out_shardings` are still instances of `MeshPspecSharding` even if `Array`s are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances.
  * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled.
  * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used.

* Checking of sharding with `aval` has a handler system to deal with sharding instances.
  * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is that each transformation has a different way of checking the sharding. The best example of this is `pjit` vs `xmap`: they each have a different way to check whether an aval is sharded properly with respect to the given sharding, because `pjit` and `xmap` have different ways of expressing sharding.

* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits; the `Sharding` instances should handle that for us (see the sketch after this entry).

* _pjit_lower still depends on mesh, which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow up CL.
  * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too.
  * Also it has `pxla.resource_typecheck`, which I haven't figured out how to move to the sharding interface.

* `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
  * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998

* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach.
  * MeshPspecSharding handles this via `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done.

PiperOrigin-RevId: 459548974
2022-07-07 10:41:52 -07:00
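A sketch of the hashing point above: two structurally identical sharding instances compare and hash equal, so they can serve directly as cache keys (mesh setup is illustrative):
```python
import numpy as np
import jax
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.sharding import MeshPspecSharding

mesh = Mesh(np.array(jax.devices()), ('x',))
s1 = MeshPspecSharding(mesh, P('x'))
s2 = MeshPspecSharding(mesh, P('x'))
assert s1 == s2 and hash(s1) == hash(s2)  # no canonicalized pspecs needed
```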
Roy Frostig
f12af93258 refactor stages types, adding methods for text and for cost/memory analyses
Re-organizing things this way in order to:

* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.

* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.

For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions.

Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.

PiperOrigin-RevId: 458574166
2022-07-01 17:35:53 -07:00
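A usage sketch of the methods introduced here, with the caveats above (`f` and `x` are placeholders):
```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2   # placeholder function
x = jnp.arange(4.0)

lowered = jax.jit(f).lower(x)
print(lowered.as_text())            # the lowering, as text

compiled = lowered.compile()
print(compiled.as_text())           # the executable, as text
cost = compiled.cost_analysis()     # best-effort; may be None per backend
mem = compiled.memory_analysis()    # likewise intentionally undefined
```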
Yash Katariya
e5031d15de Disable the XLA sharding propagation test for SE, because XLA sharding propagation, which is activated when out_axis_resources is not specified in pjit, is not supported on SE.
PiperOrigin-RevId: 456391444
2022-06-21 17:40:31 -07:00
Jake VanderPlas
51da9eb237 [x64] make pjit_test compatible with strict dtype promotion 2022-06-17 16:25:36 -07:00
Yash Katariya
6ed94ef876 First CL to integrate jax.Array into pmap.
* If `config.jax_array` is enabled, output from pmap will be `Array`s (see the sketch after this entry).
* `Array`s as inputs are accepted by pmap (as shown in the test). Currently `pxla.make_sharded_device_array` creates SDAs specially for pmap here: https://github.com/google/jax/blob/main/jax/interpreters/pxla.py#L549. So a similar approach can be done for creating `Array`s specially for pmap (see the test).
`device_put_sharded` also creates SDAs for pmap.
* `Array`s that are output from `pmap` cannot be passed into `pjit` for now. Currently even SDAs from pmap that are passed into pjit are resharded, which has a huge cost, so this kind of code is not widely used anyway. I can look into relaxing this restriction in the future.

TODOs:
* Add checks for checking if pmap sharding matches the input arrays which I will add in a follow up CL immediately.
* Figure out how to use existing tests for pmap, pjit, xmap, etc.
PiperOrigin-RevId: 455519748
2022-06-16 19:52:31 -07:00
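A hedged sketch of the first bullet above (assuming the jax_array flag; shapes are illustrative):
```python
import jax
import jax.numpy as jnp

jax.config.update('jax_array', True)

n = jax.local_device_count()
xs = jnp.arange(n * 4.0).reshape(n, 4)  # one leading shard per device
ys = jax.pmap(lambda x: x * 2)(xs)      # with jax_array on, ys is an Array
```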
Yash Katariya
a7160653ce Add __array__ (for device_get), _npy_value, block_until_ready, delete and _check_if_deleted to Array.
PiperOrigin-RevId: 454741685
2022-06-13 18:08:31 -07:00
Yash Katariya
037b326453 Fix the CI failure caused by changes to XLA:GPU sharding propagation at HEAD.
PiperOrigin-RevId: 454293072
2022-06-10 20:46:49 -07:00
Yash Katariya
8e71cde065 Allow values in in_axis_resources when Array is an input, and make it act like GDA until we remove in_axis_resources. This would also help in transitioning to Array.
PiperOrigin-RevId: 454287962
2022-06-10 19:55:47 -07:00
Yash Katariya
6e07d5c141 Allow optional out_axis_resources for jax.Array. I'll look into enabling this for all types in a follow up CL. If out_axis_resources is _UNSPECIFIED, then we let classic spmd choose the sharding for us. If you want to use the AUTO spmd partitioner, pass pjit.AUTO to out_axis_resources explicitly.
PiperOrigin-RevId: 454284276
2022-06-10 19:12:32 -07:00
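A sketch of the two modes described above (step_fn is a placeholder; the AUTO import path follows the message's `pjit.AUTO`):
```python
from jax.experimental.pjit import pjit, AUTO

def step_fn(x):   # placeholder for a real train step
    return x * 2

# omit out_axis_resources: classic spmd chooses the output sharding
f = pjit(step_fn, in_axis_resources=None)

# opt in to the AUTO spmd partitioner explicitly
g = pjit(step_fn, in_axis_resources=AUTO, out_axis_resources=AUTO)
```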
Yash Katariya
be41c8c1d3 Add pjit support for Array. Array takes the same codepath as GDA, so there are very few modifications to pjit. Add aval, shard_args, and result handlers for Array.
PiperOrigin-RevId: 454160854
2022-06-10 07:32:16 -07:00
Yash Katariya
f21898196d Support a mix of user sharding and auto sharding in pjit.
PiperOrigin-RevId: 453215791
2022-06-06 09:27:17 -07:00
Yash Katariya
998171d3c8 Fix the mypy type error. It turns out compiled.out_tree.unflatten is not a tuple, but compiled.in_tree.unflatten is a tuple of args and kwargs.
PiperOrigin-RevId: 450533405
2022-05-23 15:02:28 -07:00
Yash Katariya
88207ae196 Allow usage of auto sharding only with global semantics.
If a user passes _global_avals=True to lower, then consider all inputs to have global semantics. This is because you can't convert all the inputs to ShapedArrays while using auto sharding. Sometimes they need to be converted to a different thing that their train_step expects (see below for such an example).

PiperOrigin-RevId: 450501974
2022-05-23 12:39:19 -07:00
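The commit's own example is elided here; as a separate illustrative sketch of the flag in use (every name except _global_avals is a placeholder):
```python
import jax.numpy as jnp
from jax.experimental.pjit import pjit, AUTO

def train_step(x):   # placeholder
    return x * 2

x = jnp.ones((8, 4))
lowered = pjit(train_step,
               in_axis_resources=AUTO,
               out_axis_resources=AUTO).lower(x, _global_avals=True)
compiled = lowered.compile()   # the auto sharding search runs here
```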
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

2022-05-17 22:14:05 +01:00
Jake VanderPlas
5d45458c7b api_util: make shaped_abstractify respect raise_to_shaped 2022-05-05 17:20:00 -07:00
Yash Katariya
874374c762 Raise a better error when an assert fails in mesh_sharding_specs
PiperOrigin-RevId: 445533883
2022-04-29 16:49:05 -07:00
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
Yash Katariya
eda5bbb514 Expose the input and output sharding on the compiled object.
PiperOrigin-RevId: 441514572
2022-04-13 10:18:25 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00