This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 659492696
The fact that src generalizes dst does not mean that they have the same implicit
tile shape (if one has an implicit dim and the other one doesn't, then they will
differ by a singleton dimension).
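As a rough illustration (a hypothetical helper, not the Mosaic implementation; the 'minor'/'second_minor' labels here are stand-ins for the dialect's implicit dims), an implicit dim contributes a singleton next to the trailing dimensions, which is why the implicit shapes can differ in rank:
```py
# Hypothetical sketch of "implicit shape": a layout with an implicit dim and
# one without it see shapes that differ by a singleton dimension.
def implicit_shape(shape, implicit_dim=None):
  if implicit_dim is None:             # no implicit dim: shape used as-is
    return tuple(shape)
  if implicit_dim == 'second_minor':   # singleton inserted before the last dim
    return tuple(shape[:-1]) + (1, shape[-1])
  if implicit_dim == 'minor':          # singleton appended after the last dim
    return tuple(shape) + (1,)
  raise ValueError(implicit_dim)

assert implicit_shape((128,), 'second_minor') == (1, 128)
assert implicit_shape((8, 128)) == (8, 128)
```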
PiperOrigin-RevId: 658775019
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.
PiperOrigin-RevId: 658497960
We're constantly hitting unimplemented relayouts, but it's hard to even know what's
in there given the way the code is written. This is the first of a few clean-up CLs
that aim to partition the process into steps with clear responsibilities. It should
help us better understand what's missing.
PiperOrigin-RevId: 658318811
The backend support for the new custom call was added on June 28th.
Also add a backwards compatibility test for the new custom call.
PiperOrigin-RevId: 658011228
Two helper functions for implementing FFI calls were included directly alongside jaxlib's CPU kernels, but they will be useful for the GPU kernels as well. This moves them into ffi_helpers so that they are accessible from there too.
PiperOrigin-RevId: 658002501
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.
PiperOrigin-RevId: 657186057
This should help with understanding cuTensorMapEncodeTiled failures, since
CUDA doesn't provide any details beyond the error return code.
Note that this change also ensures that TMA descriptors are 64-byte aligned.
PiperOrigin-RevId: 656062820
In particular, test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices, and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.
PiperOrigin-RevId: 655686102
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 655484166
We will implement a more efficient relayout according to the configs in the rewrite ctx, such as `hardware_generation`, `max_sublanes_in_scratch`, and so on, so it makes sense to change the relayout interface to take the ctx (including the Python bindings). Now we can also define a rewrite ctx in `apply_vector_layout_test`, which makes it easier to test some advanced cases (e.g., mxu_shape changes, or max_sublanes_in_scratch changes for rotate and relayout).
PiperOrigin-RevId: 655350013
This CL removes the funcOp from the RewriteContext of the apply-vector-layout pass (since only one function is using it) and uses the context to create the pass instead of a long list of arguments. We will need to add more args (the target's bank counts) to create apply-vector-layout.
PiperOrigin-RevId: 655329321
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
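For example (a minimal sketch, assuming the flag is exposed through the usual `jax.config.update` mechanism), opting in would look like:
```py
import jax

# Opt in to the Shardy lowering path while it is still behind the flag.
jax.config.update('jax_use_shardy_partitioner', True)
```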
PiperOrigin-RevId: 655127611
This affects the (packing, 128) -> (8 * packing, 128) and 32-bit (8, 128),-2 -> (8, 128) retilings:
- No longer always broadcast the first sublane of a vreg before blending, which is usually unnecessary. Rotate instead, unless dst requires replicated offsets in (1, 128) -> (8, 128).
  For (8, 128),-2 -> (8, 128), with our current restrictions, the first vreg always already has the sublane in the right position, so the broadcast is always wasteful.
- Unclear if rotate is always better than broadcast, but it doesn't make sense to broadcast the first vreg yet rotate the others.
This is some cleanup prior to removing some offset restrictions for (8, 128),-2 -> (8, 128).
PiperOrigin-RevId: 654935883
It's unused, buggy (it returns a reference to a local copy of the array), and `ArrayRef` already has a ctor that takes a `std::array`.
PiperOrigin-RevId: 654916697
- Sublane unfolding was not being checked for non-empty implicit dims, e.g. (2, 2, 128, 1) -> (2, 256) would not work.
- Noop squeeze/unsqueeze paths in infer-vector-layout, when the source has ImplicitDim::kNone, were forcing native tiling for some reason.
- 1D lane squeeze was always assigning a bitwidth of 32.
- Maybe others
PiperOrigin-RevId: 653910942
* If the bitwidth does not change after the bitcast:
  - We can bitcast the input with any vector layout.
* If the bitwidth changes after the bitcast:
  - We can bitcast the input with a sublane offset that is a multiple of the ratio of the bitwidths (see the sketch below).
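As a rough illustration (a hypothetical helper, not the actual Mosaic check; the direction of the ratio is my assumption), the offset condition amounts to:
```py
# Hypothetical sketch of the bitcast layout constraint described above.
def bitcast_sublane_offset_ok(src_bitwidth: int, dst_bitwidth: int,
                              sublane_offset: int) -> bool:
  if src_bitwidth == dst_bitwidth:
    return True  # any vector layout works when the bitwidth is unchanged
  ratio = max(src_bitwidth, dst_bitwidth) // min(src_bitwidth, dst_bitwidth)
  return sublane_offset % ratio == 0  # offset must be a multiple of the ratio

assert bitcast_sublane_offset_ok(32, 16, 2)      # ratio 2: offset 2 is allowed
assert not bitcast_sublane_offset_ok(32, 8, 2)   # ratio 4: offset 2 is not
```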
PiperOrigin-RevId: 653375579
Also fix a bug in the (1, 128 * packing) -> (packing, 128) retiling where the part index could be incremented OOB.
Note: Many relayouts might be inefficient for implicit dims. If, for example, implicit dim is kSecondMinor, retiling might blend tiles that are only padding. This also applies to kNone implicit dim with small shapes, however, so any optimizations should be written based on the implicit shape.
PiperOrigin-RevId: 653209744
This CL supports memref shapecast (see the sketch after this list):
1. If the tile is (1, 128), we support a shapecast on any dim.
2. If the shapecast is on the sublane dim, we only support tile-aligned shapes.
3. If the shapecast is on a non-tiling dim, we support any shapecast.
4. All other cases are considered invalid memref shapecasts.
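A loose illustration of the rules (a hypothetical predicate, not the actual implementation; it assumes the tile applies to the last two dims and that only one kind of dim changes at a time):
```py
# Hypothetical sketch of the memref shapecast rules above, for a memref whose
# last two dims are tiled as (sublanes, lanes).
def memref_shapecast_ok(old_shape, new_shape, tile):
  sublanes, lanes = tile
  if tile == (1, 128):
    return True  # rule 1: any dim may be reshaped
  if old_shape[-2:] == new_shape[-2:]:
    return True  # rule 3: only non-tiling (leading) dims change
  if old_shape[-1] == new_shape[-1]:
    # rule 2: the sublane dim changes; both shapes must stay tile aligned
    return old_shape[-2] % sublanes == 0 and new_shape[-2] % sublanes == 0
  return False   # rule 4: everything else is invalid

assert memref_shapecast_ok((4, 16, 128), (2, 2, 16, 128), tile=(8, 128))  # rule 3
assert not memref_shapecast_ok((4, 12, 128), (4, 24, 64), tile=(8, 128))  # rule 4
```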
PiperOrigin-RevId: 651924552
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 651691430