diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 2e2694452..c4e785e50 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -466,13 +466,22 @@ class GPUMesh:
           "Requested too many CUDA threads per block. Each Mosaic thread"
           " corresponds to 128 CUDA threads."
       )
+    if self.cluster:
+      raise NotImplementedError(
+          "Pallas/MosaicGPU does not support clusters yet."
+      )
 
   @property
   def shape(self):
     if self.num_threads is not None:
-      pairs = zip(self.axis_names, (*self.grid, self.num_threads))
+      pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
     else:
-      pairs = (*zip(self.axis_names, self.grid), (_WARPGROUP_AXIS_NAME, 1))
+      pairs = tuple(
+          zip(
+              (*self.axis_names, _WARPGROUP_AXIS_NAME),
+              (*self.grid, *self.cluster, 1),
+          )
+      )
     return collections.OrderedDict(pairs)
 
 
@@ -485,11 +494,10 @@ def _gpu_mesh_discharge_rule(
 ):
   del out_avals
   assert isinstance(mesh, GPUMesh)
-  if mesh.grid or mesh.cluster:
+  if mesh.cluster:
     raise NotImplementedError
   if mesh.num_threads is None:
     raise NotImplementedError
-  threads_axis_name, num_threads = list(mesh.shape.items())[0]
   def body(*args):
     # Due to aliasing, args contains aliased inputs and outputs so we remove
     # outputs.
@@ -503,7 +511,7 @@ def _gpu_mesh_discharge_rule(
       in_specs=[any_spec] * len(in_avals),
       out_specs=[any_spec] * len(in_avals),
       input_output_aliases={i: i for i in range(len(in_avals))},
-      grid=((threads_axis_name, num_threads),),
+      grid=tuple(mesh.shape.items()),
   )(*args)
   return out, ()
 
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 32c1b1dcf..ec5584233 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1125,7 +1125,6 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
     if axis_name == grid_names[-1]:
       return mgpu.warpgroup_idx(sync=False)
     else:
-      raise NotImplementedError  # The code below is untested
       idx = grid_names.index(axis_name)
       return arith_dialect.index_cast(
           ir.IntegerType.get_signless(32),
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 4b86cc21f..22ae3e699 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -1025,6 +1025,28 @@ class CoreMapTest(PallasTest):
         f(), np.repeat(np.arange(2), 128).reshape(2, 128)
     )
 
+  def test_multiple_wg_with_grid(self):
+    mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))
+
+    @jax.jit
+    def f():
+      @pl.run_state
+      def inner(y_ref):
+        @pl.core_map(mesh)
+        def kernel():
+          xy_idx = jax.lax.axis_index(("x", "y"))
+          yx_idx = jax.lax.axis_index(("y", "x"))
+          wg_idx = jax.lax.axis_index("wg")
+          num_wgs = jax.lax.psum(1, "wg")
+          y_ref[xy_idx, wg_idx] = jnp.broadcast_to(
+              yx_idx * num_wgs + wg_idx, (128,)
+          )
+      y_init = jnp.zeros((4, 2, 128), np.int32)
+      return inner(y_init)
+
+    np.testing.assert_array_equal(
+        f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
+    )
 
 if __name__ == "__main__":
   absltest.main()
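
For reference, here is a minimal usage sketch of what this change enables: a plgpu.GPUMesh that combines a multi-dimensional grid with multiple warpgroups (num_threads) under pl.core_map. It mirrors the new test_multiple_wg_with_grid test; the import paths are assumptions based on the usual Pallas/Mosaic GPU layout and are not part of the diff.

    # Sketch only -- mirrors test_multiple_wg_with_grid above.
    # Import paths are assumed; they do not appear in the diff.
    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    # A 2x2 grid of blocks with 2 Mosaic threads (warpgroups) per block.
    mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))

    @jax.jit
    def f():
      @pl.run_state
      def inner(y_ref):
        @pl.core_map(mesh)
        def kernel():
          # Both the grid axes and the warpgroup axis are addressable
          # inside the kernel via jax.lax.axis_index.
          xy_idx = jax.lax.axis_index(("x", "y"))
          wg_idx = jax.lax.axis_index("wg")
          y_ref[xy_idx, wg_idx] = jnp.broadcast_to(wg_idx, (128,))
      return inner(jnp.zeros((4, 2, 128), jnp.int32))

The sketch writes one row per (block, warpgroup) pair, so each of the 4 grid blocks fills two rows indexed by its warpgroup id, matching the indexing pattern exercised by the new test.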