The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.
PiperOrigin-RevId: 716596848
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager
Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
and multiple devices.
Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.
Fixes https://github.com/jax-ml/jax/issues/25671.
PiperOrigin-RevId: 716309978
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
- Change the `async_load` lowering to manage the single thread context.
- Use a predicate for the top-level arrive_expect. If we want to hide this further, we can have a warp-group level op that lowers to a single-threaded context.
PiperOrigin-RevId: 716219730
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
The motivation behind this change is twofold:
1. it simplifies test writing (no need to produce arbitrary, manual, non-splat
constants to produce arguments with a strided layout);
2. it'll allow running layout inference on different `FuncOp`s in isolation,
before inlining.
While the primary motivation is to simplify test writing for upcoming changes,
`2.` is useful if we ever intend to call functions whose body's layout we have
inferred from other functions. It's not clear to me that we have a use case for
that, but the theoretical benefit is worth pointing out.
Crucially, layout inference does not set default layouts for `FuncOp`s, since
the caller may choose a different layout for its arguments. As a result, there
is also no layout inference rule for `func.FuncOp`.
PiperOrigin-RevId: 716158516
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.
This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.
PiperOrigin-RevId: 715837855
The error message produced by MLIR is not really clear, but AFAICT the crash
was caused by the "temporary module" hack we use in the lax.cond lowering
rule.
PiperOrigin-RevId: 715785632
This is part of a sequence of changes to ensure that the debugging information
is propagated properly.
Additional cleanup:
* Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the
difference from the similar field in `JaxprDebugInfo`
* Added more type declarations
value of the `use_shardy_partitioner` feature flag.
Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that.
PiperOrigin-RevId: 715293018