The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan to have frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently, Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline, running while the program is still StableHLO. Partitioning (the system that adds collectives like all-gathers/all-reduces) will still be the GSPMD partitioner, but next year the Shardy partitioner will be developed, allowing propagation and partitioning to live entirely in MLIR as the first passes in the pipeline. The pipeline would then be:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```
Shardy will initially be hidden behind the `jax_use_shardy_partitioner` flag before being enabled by default in the future.
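For reference, a minimal way to opt in, assuming the standard JAX config mechanism (the flag name comes from the text above; everything else is illustrative):
```py
import jax

# Opt into Shardy propagation/partitioning (left off by default for now).
jax.config.update('jax_use_shardy_partitioner', True)
```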
PiperOrigin-RevId: 655127611
Memory barriers are necessary to prevent excessive run-ahead in a collective
pipeline, but the implementation can be tricky (both in terms of calculating
the right arrival count and dividing the signalling responsibility between
threads). I largely tried to follow the practices that CUTLASS established,
although I still do not understand why it swizzles the cluster for signalling.
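As a rough illustration of the arrival-count arithmetic, here is a minimal sketch assuming one designated signalling thread per block in the cluster slice (illustrative only, not the actual implementation):
```py
import math

def barrier_arrival_count(cluster_shape: tuple[int, ...],
                          collective_dims: tuple[int, ...]) -> int:
  # Every participating block signals the barrier once, so the barrier must
  # expect one arrival per block in the cluster slice.
  return math.prod(cluster_shape[d] for d in collective_dims)
```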
PiperOrigin-RevId: 655098234
Instead of asking the user to compute the transfer size, manually slice up the
transfer, and compute and specify the multicast mask, we fold all of that
functionality into the `async_copy` function. The copy should be called by all
blocks in a given cluster slice along the specified dimension, and it will
collectively load all the requested data into all blocks in that slice.
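A minimal, self-contained sketch of the bookkeeping that is now folded into `async_copy` (the helper name, shapes, and even divisibility are assumptions, purely for illustration):
```py
def block_copy_params(block_idx: int, num_blocks: int, total_rows: int):
  # Each block in the cluster slice issues its own sub-transfer...
  rows_per_block = total_rows // num_blocks   # assumes even divisibility
  start_row = block_idx * rows_per_block
  # ...and multicasts it to every block in the slice, so all blocks end up
  # with all of the requested data.
  multicast_mask = (1 << num_blocks) - 1
  return start_row, rows_per_block, multicast_mask
```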
PiperOrigin-RevId: 655077439
Raise an error with a useful message rather than failing an assert with no message.
We will likely never support unmapped outputs and reverse-mode autodiff (i.e.
grad or vjp) with pmap, but it can be done with shard_map.
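A hypothetical repro sketch of the unsupported combination, assuming the unmapped output comes from `out_axes=None` (the exact error message is not shown here):
```py
import jax
import jax.numpy as jnp

# out_axes=None yields an unmapped (replicated) output; reverse-mode autodiff
# through it is the combination described above.
f = jax.pmap(lambda x: jax.lax.psum(x.sum(), 'i'), axis_name='i', out_axes=None)
x = jnp.ones((jax.device_count(), 4))
jax.grad(lambda y: f(y))(x)  # expected to raise an informative error now
```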
Fixes #14296
- Create the metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks that have disabled the compilation cache.
- Modify the existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric to make it easier to find the two adoption-related metrics in the code.
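A hedged sketch of how the reporting might look, assuming the public `jax.monitoring.record_event` API (the metric names come from the items above; the helper itself is illustrative):
```py
from jax import monitoring

def _record_cache_adoption_metrics(cache_enabled: bool) -> None:
  # Report exactly one of the two adoption metrics per task.
  if cache_enabled:
    monitoring.record_event('/jax/compilation_cache/tasks_using_cache')
  else:
    monitoring.record_event('/jax/compilation_cache/task_disabled_cache')
```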
PiperOrigin-RevId: 654970654
This affects the (packing, 128) -> (8 * packing, 128) and 32-bit (8, 128),-2 -> (8, 128) retilings:
- No longer always broadcast the first sublane of a vreg before blending, which is usually unnecessary. Rotate instead, unless the dst requires replicated offsets in (1, 128) -> (8, 128). For (8, 128),-2 -> (8, 128), with our current restrictions, the first vreg always already has the sublane in the right position, so the broadcast is always wasteful.
- It is unclear whether rotate is always better than broadcast, but it doesn't make sense to broadcast the first vreg yet rotate the others.
This is some cleanup prior to removing some offset restrictions for (8, 128),-2 -> (8, 128).
PiperOrigin-RevId: 654935883
It's unused, buggy (it will return a reference to a local copy of the array), and `ArrayRef` already has a ctor that takes a `std::array`.
PiperOrigin-RevId: 654916697
Previously, `np.take(tile_assignment_devices, permute_order)` did not preserve the invariant of maintaining the concrete device order after permutation (relative to the old array).
But doing `np.take(permute_order, tile_assignment_devices)` maintains that invariant and hence is the correct thing to do.
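A toy numeric illustration of how the two expressions differ (the values are made up and do not model the sharding invariant itself):
```py
import numpy as np

tile_assignment_devices = np.array([2, 0, 3, 1])
permute_order = np.array([1, 0, 3, 2])

# Old: index the device ids by the permutation.
np.take(tile_assignment_devices, permute_order)  # -> array([0, 2, 1, 3])
# New: index the permutation by the device ids.
np.take(permute_order, tile_assignment_devices)  # -> array([3, 1, 2, 0])
```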
PiperOrigin-RevId: 654884965
We want to allow `psum(x, axes)` regardless of how `x` is replicated. That
means when we rewrite it into the stricter `psum2`, which can only sum over
non-replicated axes, we need to insert a pbroadcast like this:
```
psum(x, axes) == psum2(pbroadcast(x, axes & input_replicated_axes), axes)
```
In words, we need to first `pbroadcast` over all those axes we're about to sum
over but that the input is already replicated over.
We write it as a comprehension over mesh.axis_names, rather than just that set
intersection, to ensure deterministic ordering, since Python set operations are
not guaranteed to be deterministic. There are other places in the file where we
don't ensure deterministic ordering; someday I'll come back and fix those.
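A minimal sketch of the rewrite as described, with the collaborators passed in explicitly since the real internals aren't shown here (`pbroadcast`, `psum2`, and the axis sets are stand-ins):
```py
def rewrite_psum(x, axes, mesh_axis_names, input_replicated_axes,
                 pbroadcast, psum2):
  # Iterate over the mesh axis names (not the raw set intersection) so the
  # chosen axis order is deterministic.
  to_pbroadcast = tuple(a for a in mesh_axis_names
                        if a in axes and a in input_replicated_axes)
  return psum2(pbroadcast(x, to_pbroadcast), axes)
```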
Fixes #19175