`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which carry less
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code that preprocesses the
block and grid parameters, and simplifies the code downstream.
`grid` now defaults to `()` instead of `None`.
Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` field makes
it unnecessary to process `BlockMapping`s zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.
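As a rough illustration (a sketch with a hypothetical class name and simplified fields, not the actual `pallas_core` code), the resulting structure looks roughly like this:
```python
import dataclasses

import jax
from jax import core as jax_core


@dataclasses.dataclass(frozen=True)
class BlockMappingSketch:
  """Hypothetical stand-in for the real BlockMapping (sketch only)."""
  block_shape: tuple[int, ...]
  block_aval: jax_core.ShapedArray          # abstract value of a single block
  array_shape_dtype: jax.ShapeDtypeStruct   # shape/dtype of the whole operand
  source: str                               # where the BlockSpec came from, for error messages

  def check_invariants(self) -> None:
    # Called from tests, or when `jax.config.jax_enable_checks` is true.
    assert len(self.block_shape) == len(self.array_shape_dtype.shape), self.source
```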
Added more fields and a `check_invariants` method to `GridMapping`, since it is
such an important data structure. The new fields are `index_map_avals` and
`index_map_tree` (to encode the calling convention for the index map functions),
along with `num_inputs` and `num_outputs`. The latter two make it possible to
recover the `in_shapes` and `out_shapes` from the `GridMapping`.
Previously there was some redundancy of information between
`in_shapes`/`out_shapes` and the grid mapping.
Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.
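Building on the sketch above, and assuming `block_mappings` stores the input mappings followed by the output mappings, recovering the shapes from a `GridMapping`-like structure could look roughly like this (again hypothetical names, not the real code):
```python
@dataclasses.dataclass(frozen=True)
class GridMappingSketch:
  """Hypothetical stand-in for the real GridMapping (sketch only)."""
  grid: tuple[int, ...]                           # `grid` defaults to ()
  block_mappings: tuple[BlockMappingSketch, ...]  # assumed: inputs, then outputs
  index_map_avals: tuple                          # flat avals of the index-map arguments
  index_map_tree: object                          # pytree structure of those arguments
  num_inputs: int
  num_outputs: int

  @property
  def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
    return tuple(bm.array_shape_dtype
                 for bm in self.block_mappings[:self.num_inputs])

  @property
  def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
    start = self.num_inputs
    return tuple(bm.array_shape_dtype
                 for bm in self.block_mappings[start:start + self.num_outputs])
```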
Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.
Removed some dead code for implementing the interpret mode.
The previous handling of hoisted consts did not account for them in
`in_shapes`. This is now fixed, since we no longer keep track of
`in_shapes` separately.
Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of "mapped" in block shapes.
Added a test for the calling convention, including dynamic grid dimensions.
There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. These
should be a bunch of local changes, which I will do separately
once I merge this large global change.
This is slightly less convenient than our previous approach but it has two main upsides:
1. It lets us automatically emit necessary fences and barriers for use with block clusters
2. It lets us share the same block/cluster barrier for all initializations of mbarriers
This change also moves away from the nvgpu dialect for barriers and allocates them in
dynamic SMEM instead of relying on static SMEM. This should give us more control over
SMEM layouts and alignments, and it simplifies the lowering process.
PiperOrigin-RevId: 655493451
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 655484166
We don't use this capability just yet, but I want to start allocating barriers
as part of the scratch and this will push the unions deeper into the tree.
PiperOrigin-RevId: 655475839
We will implement a more efficient relayout based on the configuration carried in the rewrite context, such as `hardware_generation`, `max_sublanes_in_scratch`, and so on, so it makes sense to change the relayout interface to take the context (including the Python bindings). We can now also define a rewrite context in `apply_vector_layout_test`, which makes it easier to test more advanced cases (e.g., changing `mxu_shape` or `max_sublanes_in_scratch` for rotate and relayout).
PiperOrigin-RevId: 655350013
This CL removes the `funcOp` from the `RewriteContext` of the apply-vector-layout pass (since only one function uses it) and uses the context to create the pass instead of a long list of arguments. We will need to add more arguments (the target's bank counts) to create apply-vector-layout.
PiperOrigin-RevId: 655329321
This allows code like this:
```python
def f(x):
  mesh = pltpu.create_tensorcore_mesh('core')
  y = jnp.zeros_like(x)
  @state_discharge.run_state
  def inner(refs):
    x_ref, y_ref = refs
    def kernel():
      def alloc(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
    shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
                        check_rep=False)()
  _, y = inner((x, y))
  return y
```
Why? pallas_call as an API has a lot of responsibilities:
1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via dimension_semantics and grid)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch
This change allows you to express pallas_call *compositionally* using existing APIs.
1. Creating Refs out of arrays -> run_state
2. Parallelizing execution over cores -> shmap w/ a special mesh
3. Pipelining -> emit_pipeline
4. Allocating scratch spaces -> run_scoped (which we could generalize to run_state)
5. Scalar prefetch -> run_scoped + a DMA
The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed but we want to make sure it is possible to support.
PiperOrigin-RevId: 655320587
The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline, running while the program is still StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing propagation and partitioning to happen entirely in MLIR as the first passes in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially, before being enabled by default in the future.
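For example, assuming the flag is exposed through the usual `jax.config` mechanism, it can be enabled explicitly in the meantime:
```python
import jax

# Opt in to the experimental Shardy lowering path. The flag name comes from
# the description above; treat this snippet as a sketch, not official usage.
jax.config.update('jax_use_shardy_partitioner', True)
```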
PiperOrigin-RevId: 655127611
Memory barriers are necessary to prevent excessive run ahead in a collective
pipeline, but the implementation can be tricky (both in terms of calculating
the right arrival count and dividing the signalling responsibility between
threads). I largely tried to follow the practices that CUTLASS established,
although I still do not understand why it swizzles the cluster for signalling.
PiperOrigin-RevId: 655098234
Instead of asking the user to compute the transfer size, manually slice up the
transfer, and compute and specify the multicast mask, we fold all of that functionality
into the `async_copy` function. The copy should be called by all blocks in a given
cluster slice along the specified dimension, and will collectively load all the
requested data into all blocks in that slice.
PiperOrigin-RevId: 655077439
rather than failing an assert with no message
We will likely never support unmapped outputs together with reverse-mode autodiff
(i.e. grad or vjp) in pmap, but it can be done with shard_map.
Fixes #14296
- Created the metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks which have disabled the compilation cache (see the sketch below).
- Modified the existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric and made it easier to find the two adoption-related metrics in the code.
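A rough sketch of how such a task-level beacon metric can be recorded via JAX's internal monitoring helper; the helper name and call site here are assumptions, not the actual change:
```python
from jax._src import monitoring

def _report_cache_adoption_metrics(cache_enabled: bool) -> None:
  # Hypothetical helper: record exactly one of the two adoption metrics per task.
  if cache_enabled:
    monitoring.record_event('/jax/compilation_cache/tasks_using_cache')
  else:
    # New beacon metric for tasks that have disabled the compilation cache.
    monitoring.record_event('/jax/compilation_cache/task_disabled_cache')
```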
PiperOrigin-RevId: 654970654
This affects the (packing, 128) -> (8 * packing, 128) and 32-bit (8, 128),-2 -> (8, 128) retilings:
- No longer always broadcast the first sublane of a vreg before blending, which is usually unnecessary. Rotate instead, unless dst requires replicated offsets in (1, 128) -> (8, 128).
  For (8, 128),-2 -> (8, 128), with our current restrictions, the first vreg always has the sublane in the right position already, so the broadcast is always wasteful.
- It is unclear whether rotate is always better than broadcast, but it does not make sense to broadcast the first vreg yet rotate the others.
This is some cleanup prior to removing some offset restrictions for (8, 128),-2 -> (8, 128).
PiperOrigin-RevId: 654935883