.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.
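For example, the remaining supported conversions:
```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(4)
a = np.asarray(x)      # conversion via the NumPy array protocol
b = jax.device_get(x)  # explicit device-to-host transfer
```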
PiperOrigin-RevId: 469984029
Currently
```
import jax
```
populates `jax._src` among the names exported from `jax`. This change prepares for no longer exporting `jax._src` by default.
In particular, explicitly import modules from `jax._src` and refer to those imports, rather than assuming the contents of `jax._src` will be around later. This is a common pattern in tests; for example:
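A representative before/after (the specific submodule is illustrative; the real tests touch various `jax._src` modules):
```python
pairs = [(1, 'a'), (2, 'b')]

# Before: relied on `import jax` transitively exposing `jax._src`.
import jax
xs, ys = jax._src.util.unzip2(pairs)

# After: import the submodule explicitly and refer to the local name.
from jax._src import util
xs, ys = util.unzip2(pairs)
```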
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
I have added comments in various places to explain the changes.
Partial eval no longer depends on `MeshPspecSharding`; it now depends on `OpShardingSharding`.
TODO: fix the round trip through `MeshPspecSharding` in the vmap batching handlers.
PiperOrigin-RevId: 465621165
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` already handles three paths (the SPMD lowering path, the non-SPMD lowering path, and the xmap SPMD lowering path), and I didn't want to add a fourth path to it for general shardings.
* `lower_sharding_computation` always works in SPMD mode since it is only used by pjit. The majority of the logic is the same; the only difference is that no `mesh` exists in this function. (See the sketch after this list.)
* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.
* `AUTO` and `UNSPECIFIED` cannot be used without a mesh right now, but I have a CL to fix this.
* The rest of the changes are to make all other functions play nicely with `Sharding` instances.
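A minimal sketch of the split described above, with stub names and signatures that are illustrative rather than the actual `pxla.py` API:
```python
def lower_mesh_computation(fun, mesh, shardings):
    """Stub: handles the SPMD, non-SPMD, and xmap SPMD lowering paths."""

def lower_sharding_computation(fun, shardings):
    """Stub: SPMD-only lowering used by pjit; no `mesh` in scope."""

def _lower(fun, *, mesh=None, shardings=None):
    if mesh is not None:
        return lower_mesh_computation(fun, mesh, shardings)
    return lower_sharding_computation(fun, shardings)
```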
PiperOrigin-RevId: 461260553
* Support partial mentions of `AUTO`, which GDA currently supports and which is used in pax. Added tests for all of this (see the example after this list).
* As a consequence of this, I lifted the restriction that `in_axis_resources` could not be provided to pjit under `config.jax_array`.
* Made all auto sharding tests parameterized to test both GDA and `Array`.
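A hedged sketch of a partial `AUTO` mention (import paths are from this era of JAX and have since moved; `AUTO` is assumed importable from `jax.experimental.pjit` as in the tests):
```python
import numpy as np
import jax
from jax.experimental import maps, PartitionSpec as P
from jax.experimental.pjit import pjit, AUTO

mesh = maps.Mesh(np.array(jax.devices()), ('x',))
f = pjit(
    lambda a, b: a @ b,
    in_axis_resources=(AUTO, P('x', None)),  # mix AUTO with an explicit spec
    out_axis_resources=AUTO,                 # let auto sharding pick the output
)
```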
PiperOrigin-RevId: 459776152
* All `in_axis_resources` and `out_axis_resources` are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` are inferred from the inputs.
* `out_shardings` are still instances of `MeshPspecSharding` even when `Array`s are used. In a follow-up CL, I will change `out_axis_resources` to accept `Sharding` instances.
* This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled.
* cl/458267790 is WIP for this. It adds a couple of checks in `MeshPspecSharding` too when `AUTO` is used.
* Checking a sharding against an `aval` goes through a handler system that dispatches on the sharding instance.
* The reason for creating a pjit-specific system, rather than putting this check on the sharding instances themselves, is that each transformation has a different way of checking shardings. The best example is `pjit` versus `xmap`: they express sharding differently, so each has its own way of checking whether an aval is sharded properly with respect to a given sharding. (A sketch of this handler approach follows the list.)
* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`, so we no longer have to pass around canonicalized pspecs in the new path to get cache hits; the `Sharding` instances handle that for us.
* `_pjit_lower` still depends on the mesh, which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL.
* Also, the private functions in pxla.py are used by pathways and automap, so I'll have to modify those too.
* pxla.py also has `pxla.resource_typecheck`, which I haven't figured out how to move to the sharding interface.
* `_to_xla_op_sharding` takes `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
* `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998
* `pjit`'s batching handlers add an extra dimension to the axis resources. Since how the extra dimension is added depends on the transformation, and how it is handled differs per sharding instance, I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here, which is why I went with the handler approach.
* `MeshPspecSharding` handles this via `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments where this is done. (Both handlers are sketched after this list.)
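As referenced in the bullets above, a minimal sketch of the two handler systems, with illustrative names rather than the actual internal API:
```python
# 1) pjit-specific aval/sharding checks, dispatched on the Sharding subclass.
#    xmap would keep its own table, since it expresses sharding differently.
_pjit_aval_checks = {}

def register_pjit_aval_check(sharding_cls):
    def register(fn):
        _pjit_aval_checks[sharding_cls] = fn
        return fn
    return register

def check_aval_sharding(aval, sharding):
    _pjit_aval_checks[type(sharding)](aval, sharding)

# 2) Batching: insert an entry for the mapped dimension into a parsed spec,
#    modeled here as a tuple of per-dimension partition tuples.
def insert_axis_partitions(partitions, dim, new_partitions=()):
    return partitions[:dim] + (new_partitions,) + partitions[dim:]

# Mapping a new leading axis over a spec like P('x', None):
assert insert_axis_partitions((('x',), ()), 0) == ((), ('x',), ())
```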
PiperOrigin-RevId: 459548974
Re-organizing things this way in order to:
* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.
* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.
For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on the backend. Otherwise, guarantees about its output and return type are very limited; these can differ across invocations and across JAX/jaxlib versions.
Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their names suggest.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests. (See the usage example after this list.)
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.
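For example (output contents are best-effort and backend-dependent, per the above):
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * 2.0

lowered = f.lower(jnp.float32(1.0))
print(lowered.as_text())          # the lowering, as text

compiled = lowered.compile()
print(compiled.as_text())         # compiled program text
cost = compiled.cost_analysis()   # may be None depending on the backend
mem = compiled.memory_analysis()  # may be None; structure is undefined
```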
PiperOrigin-RevId: 458574166
* If `config.jax_array` is enabled, outputs from pmap will be `Array`s.
* `Array`s as inputs are accepted by pmap (as shown in the test). Currently `pxla.make_sharded_device_array` creates SDAs specially for pmap here: https://github.com/google/jax/blob/main/jax/interpreters/pxla.py#L549, so a similar approach can be taken to create `Array`s specially for pmap (see the test).
`device_put_sharded` also creates SDAs for pmap.
* `Array`s that are output from `pmap` cannot be passed into `pjit` for now. Currently, even SDAs from pmap that are passed into pjit get resharded, which has a huge cost, so this kind of code is not common anyway. I can look into relaxing this restriction in the future. (A sketch of the new output behavior follows.)
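A minimal sketch of the new behavior (the `jax_array` flag is era-specific and later became the default):
```python
import jax
import jax.numpy as jnp

jax.config.update('jax_array', True)  # era-specific flag

n = jax.local_device_count()
out = jax.pmap(lambda x: x ** 2)(jnp.arange(n, dtype=jnp.float32))
# `out` is an `Array` sharded across local devices, not a ShardedDeviceArray.
```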
TODOs:
* Add checks that the pmap sharding matches the input arrays; I will add these in a follow-up CL immediately.
* Figure out how to use existing tests for pmap, pjit, xmap, etc.
PiperOrigin-RevId: 455519748
If a user passes `_global_avals=True` to `lower`, then consider all inputs to have global semantics. This is because you can't convert all the inputs to `ShapedArray`s while using auto sharding; sometimes they need to be converted to a different thing that their `train_step` expects (see below for such an example).
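A hedged sketch of the flag in use (`_global_avals` is a private parameter of this era's `pjit(...).lower`; the imports and signatures shown are era-specific assumptions):
```python
import numpy as np
import jax
from jax.experimental import maps, PartitionSpec as P
from jax.experimental.pjit import pjit

mesh = maps.Mesh(np.array(jax.devices()), ('x',))
f = pjit(lambda x: x + 1, in_axis_resources=P('x'), out_axis_resources=P('x'))

with mesh:
    # Treat every input as having global (not host-local) semantics:
    lowered = f.lower(np.arange(8, dtype=np.float32), _global_avals=True)
```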
PiperOrigin-RevId: 450501974