* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` already handles three paths: the SPMD lowering path, the non-SPMD lowering path, and the xmap SPMD lowering path. I didn't want to add a fourth path to it for general shardings.
* `lower_sharding_computation` works in SPMD mode since it's only used in pjit. The majority of the logic is the same. The only difference is that `mesh` does not exist in this function.
* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.
* `AUTO` and `UNSPECIFIED` cannot be used without a mesh right now, but I have a CL to fix this.
* The rest of the changes make all other functions play nicely with sharding instances.
PiperOrigin-RevId: 461260553
For very large trees of custom nodes this printing can be very verbose, with a
lot of repetition. Our internal repository also encourages very deep package
names, which exacerbates this issue.
Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:
>>> params = { .. some params .. }
>>> f = jax.jit(..).lower(params).compile()
>>> f(params) # fine
>>> params['some_new_thing'] = something
>>> f(params)
TypeError: function compiled for {treedef}, called with {treedef}.
PiperOrigin-RevId: 461190971
* Support partial mentions of `AUTO`, which is currently supported by GDA and used in pax. Added tests for all of this.
* As a consequence of this, I lifted the requirement to provide `in_axis_resources` to pjit under `config.jax_array`.
* Made all auto sharding tests parameterized to test both gda and array.
PiperOrigin-RevId: 459776152
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.
* `out_shardings` are still instances of `MeshPspecSharding` even when `Array`s are used. In a follow-up CL, I will change `out_axis_resources` to accept `Sharding` instances.
* This is also why you still need a mesh context manager when `config.jax_array` is enabled.
* cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used.
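A minimal sketch of what "inferred from the inputs" means above. The classes and `infer_in_shardings` helper here are hypothetical stand-ins, not JAX's actual internals: the point is only that each `Array`-style input carries its own sharding, so pjit can read `in_shardings` off the arguments instead of requiring them up front.

```python
# Hypothetical sketch (not JAX's real API): inputs carry shardings,
# so in_shardings can be inferred instead of passed explicitly.
from dataclasses import dataclass

@dataclass(frozen=True)
class SingleDeviceSharding:  # stand-in for the real class
    device: str

@dataclass(frozen=True)
class Array:  # stand-in: a value bundled with its sharding
    value: float
    sharding: SingleDeviceSharding

def infer_in_shardings(args):
    # Each jax.Array-style input knows how it is sharded.
    return tuple(a.sharding for a in args)

args = (Array(1.0, SingleDeviceSharding("tpu:0")),
        Array(2.0, SingleDeviceSharding("tpu:1")))
print(infer_in_shardings(args))
```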
* Checking a sharding against an `aval` goes through a handler system to deal with the different sharding instances.
* The reason for creating a `pjit`-specific system rather than putting this check on the sharding instances is that each transformation has a different way of checking the sharding. The best examples are `pjit` and `xmap`: they check whether an aval is sharded properly with respect to the given sharding in different ways, because `pjit` and `xmap` express sharding differently.
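The handler idea above can be sketched as a small registry keyed by sharding type. All names below (`register_pjit_check`, `pjit_check_aval_sharding`, the simplified `MeshPspecSharding`) are hypothetical illustrations of the pattern, not the real pjit code:

```python
# Hypothetical sketch: a pjit-specific registry of per-sharding checks,
# since each transformation validates aval-vs-sharding its own way.
_pjit_check_handlers = {}

def register_pjit_check(sharding_cls):
    def wrap(fn):
        _pjit_check_handlers[sharding_cls] = fn
        return fn
    return wrap

class MeshPspecSharding:  # simplified stand-in for the real class
    def __init__(self, spec):
        self.spec = spec  # e.g. ('x', None)

@register_pjit_check(MeshPspecSharding)
def _check_mesh_pspec(aval_shape, sharding):
    # pjit's rule here: the pspec cannot name more axes than aval has dims.
    return len(sharding.spec) <= len(aval_shape)

def pjit_check_aval_sharding(aval_shape, sharding):
    # Dispatch on the concrete sharding type.
    return _pjit_check_handlers[type(sharding)](aval_shape, sharding)

print(pjit_check_aval_sharding((8, 2), MeshPspecSharding(('x', None))))
```

xmap would register its own handlers against the same shapes, which is why the dispatch lives in the transformation rather than on the sharding classes.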
* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.
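Why `__hash__`/`__eq__` buy cache hits, as a self-contained sketch. The frozen dataclass and `lower` function are hypothetical stand-ins; the real classes implement their own equality, but the caching effect is the same:

```python
# Hypothetical sketch: equal sharding instances hash equally, so a
# lowering cache keyed on them hits without canonicalizing pspecs first.
from dataclasses import dataclass
from functools import lru_cache

@dataclass(frozen=True)  # frozen dataclasses get __hash__/__eq__ for free
class MeshPspecSharding:
    mesh_shape: tuple
    spec: tuple

@lru_cache(maxsize=None)
def lower(fun_name, in_sharding):
    return f"lowered {fun_name} with {in_sharding.spec}"

s1 = MeshPspecSharding((('x', 4),), ('x',))
s2 = MeshPspecSharding((('x', 4),), ('x',))
lower("f", s1)
lower("f", s2)  # cache hit: s1 == s2 and hash(s1) == hash(s2)
print(lower.cache_info().hits)
```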
* `_pjit_lower` still depends on the mesh, which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL.
* Also, the private functions in pxla.py are used by pathways and automap, so I'll have to modify those too.
* pxla.py also has `pxla.resource_typecheck`, which I haven't figured out how to move to the sharding interface.
* `_to_xla_op_sharding` takes `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
* `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998
* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since how the extra dimension is added depends on the transformation, and how it is handled differs per sharding instance, I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here, which is why I went with the handler approach.
* `MeshPspecSharding` handles this via `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments where this is done.
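A minimal sketch of the "insert an extra dimension" step for a `MeshPspecSharding`-style partition spec. The signature below is an assumption for illustration (the real `insert_axis_partitions` operates on parsed specs): batching inserts a new array dimension, so the spec grows an unsharded (`None`) entry at that position.

```python
# Hypothetical sketch: grow a partition spec by one (replicated) entry
# at the dimension that batching inserted.
def insert_axis_partitions(spec, dim, partitions=None):
    # `spec` is a tuple like ('x', 'y'); `dim` is where the new array
    # dimension landed; `partitions` is usually None (replicated).
    return spec[:dim] + (partitions,) + spec[dim:]

print(insert_axis_partitions(('x', 'y'), 0))  # (None, 'x', 'y')
print(insert_axis_partitions(('x', 'y'), 1))  # ('x', None, 'y')
```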
PiperOrigin-RevId: 459548974