* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.
* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.
We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.
PiperOrigin-RevId: 717588253
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.
PiperOrigin-RevId: 716771872
* mesh_cast only works when the axis types between src and dst mesh changes. Hence the name!
* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.
* src and dst mesh axis_names and axis_sizes should be the same.
TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager
Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
and multiple devices.
Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.
Fixes https://github.com/jax-ml/jax/issues/25671.
PiperOrigin-RevId: 716309978
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
In this case, the returned out_shardings should all be NamedSharding (because of NamedSharding's presence in some out_sharding's).
PiperOrigin-RevId: 714681941
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.
Extend custom_partitioner tests in pjit_test.py for Shardy sharding rule.
PiperOrigin-RevId: 713399604
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.
PiperOrigin-RevId: 713352542
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.
In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.
PiperOrigin-RevId: 713272197
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.
Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.
PiperOrigin-RevId: 707128096
The test does not clear the JAX caches, and jax.sin is a jitted closure
that's shared between all test methods, so there's no guarantee that someone
hasn't already traced sine at that same shape before. This only shows up rarely
since it depends on the subset of tests assigned to the same test executor.
PiperOrigin-RevId: 706706380
This is useful for `jax.export`, e.g., for cross-platform export when we do not have access to the actual devices for which this computation is lowered.
PiperOrigin-RevId: 705764178
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.
During lowering of shardings with auto axes, we mark the auto dims are `unspecifed_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further which is not what we want if some of the dims are user sharded.
PiperOrigin-RevId: 704911253
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.
PiperOrigin-RevId: 700537898