* Support partial mentions of AUTO, which GDA currently supports and pax uses. Added tests for all of this.
* As a consequence, I lifted the restriction that prevented passing `in_axis_resources` to pjit under `config.jax_array`.
* Made all auto sharding tests parameterized to test both GDA and `Array`.
PiperOrigin-RevId: 459776152
Without this, the compiled graph still contains a node multiplying a complex number by the constant 1+0j (the 1 is cast to complex because the other operand is complex). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph altogether.
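The pattern can be sketched with a toy op recorder. This is not JAX code and not the actual change: `ToyGraph`, `scale_always`, and `scale_skip_identity` are hypothetical names used only to illustrate why guarding on a scale of 1 keeps the complex multiply out of the recorded graph.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a traced graph: it just records emitted ops."""
    ops: List[Tuple[str, complex]] = field(default_factory=list)

    def mul(self, x: complex, c: complex) -> complex:
        # Record the multiply as a graph node, then compute the value.
        self.ops.append(("mul", c))
        return x * c


def scale_always(graph: ToyGraph, x: complex, scale: float) -> complex:
    # Always emits a multiply, even when scale == 1 (cast to 1+0j).
    return graph.mul(x, complex(scale))


def scale_skip_identity(graph: ToyGraph, x: complex, scale: float) -> complex:
    # Skip the multiply when the scale is 1, so no node is emitted at all.
    if scale == 1:
        return x
    return graph.mul(x, complex(scale))


before, after = ToyGraph(), ToyGraph()
scale_always(before, 2 + 3j, 1)
scale_skip_identity(after, 2 + 3j, 1)
```

In this sketch `before.ops` holds one `mul` node while `after.ops` stays empty, which mirrors the effect described above: a converter that cannot handle complex multiplication never sees the node.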
PiperOrigin-RevId: 459566727
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.
* `out_shardings` are still instances of `MeshPspecSharding` even if `Array`s are used. In a follow-up CL, I will change out_axis_resources to accept `Sharding` instances.
* This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled.
* cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used.
* Checking of sharding with `aval` has a handler system to deal with sharding instances.
* The reason for creating a `pjit`-specific system rather than putting this check on the sharding instances is that each transformation has a different way of checking the sharding. The best example of this is `pjit` and `xmap`: they check whether an aval is sharded properly with respect to the given sharding in different ways, because `pjit` and `xmap` have different ways to express sharding.
* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.
* _pjit_lower still depends on mesh, which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL.
* Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too.
* It also has `pxla.resource_typecheck`, which I haven't figured out how to move to the sharding interface.
* `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
* `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998
* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this depends on how each transformation adds the extra dimension, and on how each sharding instance handles it, I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here, which is why I went with the handler approach.
* MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done.
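Two of the design points above can be sketched in plain Python: value-based `__eq__`/`__hash__` on sharding objects so that equal shardings hit the same cache entry, and a per-type handler registry for checking a shape against a sharding. This is a minimal sketch, not the code in this CL; `ToyMeshSharding`, `ToySingleDeviceSharding`, `register_shape_check`, `check_shape`, and `toy_lower` are all hypothetical names, and the divisibility check is an assumed stand-in for the real aval checks.

```python
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, Dict, Optional, Tuple


@dataclass(frozen=True)  # frozen dataclasses get value-based __eq__ and __hash__
class ToyMeshSharding:
    mesh_shape: Tuple[int, ...]       # number of devices along each mesh axis
    spec: Tuple[Optional[int], ...]   # per array dim: mesh axis index, or None

@dataclass(frozen=True)
class ToySingleDeviceSharding:
    device_id: int


# Hypothetical registry: sharding type -> function validating a shape against it.
_shape_check_handlers: Dict[type, Callable] = {}

def register_shape_check(sharding_type: type):
    def decorator(fn: Callable) -> Callable:
        _shape_check_handlers[sharding_type] = fn
        return fn
    return decorator

@register_shape_check(ToyMeshSharding)
def _check_mesh(sharding: ToyMeshSharding, shape: Tuple[int, ...]) -> None:
    if len(sharding.spec) > len(shape):
        raise ValueError("spec has more entries than the array has dimensions")
    for dim, axis in zip(shape, sharding.spec):
        if axis is not None and dim % sharding.mesh_shape[axis] != 0:
            raise ValueError(
                f"dimension {dim} is not divisible by mesh axis size "
                f"{sharding.mesh_shape[axis]}")

@register_shape_check(ToySingleDeviceSharding)
def _check_single(sharding: ToySingleDeviceSharding, shape: Tuple[int, ...]) -> None:
    # A single device holds the whole array, so there is nothing to check.
    return None

def check_shape(sharding, shape: Tuple[int, ...]) -> None:
    # Dispatch on the concrete sharding type instead of a method on the instance.
    _shape_check_handlers[type(sharding)](sharding, shape)


lowering_calls = []  # records every cache miss

@lru_cache(maxsize=None)
def toy_lower(sharding, shape: Tuple[int, ...]) -> str:
    check_shape(sharding, shape)
    lowering_calls.append((sharding, shape))
    return f"lowered:{shape}"


a = ToyMeshSharding(mesh_shape=(2, 4), spec=(0, None))
b = ToyMeshSharding(mesh_shape=(2, 4), spec=(0, None))  # distinct object, equal value
toy_lower(a, (8, 3))
toy_lower(b, (8, 3))  # equal to `a`, so this call is served from the cache
```

Because `a` and `b` compare equal and hash identically, the second call does not append to `lowering_calls`; no canonicalized spec has to be threaded through to get the hit. A different transformation could keep its own registry with different rules for the same sharding types, which is the motivation given above for handlers over methods.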
PiperOrigin-RevId: 459548974