[Pallas:MGPU] Add support for grid dims in GPUMesh

No communication can happen across grid dimensions (unlike across the WG dimension),
but we still need a way to launch multiple blocks.

PiperOrigin-RevId: 688488660
This commit is contained in:
Adam Paszke 2024-10-22 04:10:11 -07:00 committed by jax authors
parent 0b3f0e11fb
commit 2db03ba54b
3 changed files with 35 additions and 6 deletions

View File

@ -466,13 +466,22 @@ class GPUMesh:
"Requested too many CUDA threads per block. Each Mosaic thread"
" corresponds to 128 CUDA threads."
)
if self.cluster:
raise NotImplementedError(
"Pallas/MosaicGPU does not support clusters yet."
)
@property
def shape(self):
  """Returns an ordered mapping from mesh axis names to their sizes.

  Axis order is (grid axes, cluster axes, warpgroup axis). When
  ``num_threads`` is ``None``, an implicit warpgroup axis of size 1 named
  ``_WARPGROUP_AXIS_NAME`` is appended instead of a user-named one.

  NOTE(review): the rendered diff contained both the pre- and post-change
  versions of the ``pairs`` assignments; this is the post-change version,
  which threads ``self.cluster`` into the axis sizes.
  """
  if self.num_threads is not None:
    # axis_names is assumed to name every grid and cluster axis plus the
    # user-chosen warpgroup axis, in that order — TODO confirm against
    # the (not visible here) GPUMesh validation logic.
    pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
  else:
    # No explicit thread axis requested: append a trivial warpgroup axis
    # of size 1 so downstream code can always rely on its presence.
    pairs = tuple(
        zip(
            (*self.axis_names, _WARPGROUP_AXIS_NAME),
            (*self.grid, *self.cluster, 1),
        )
    )
  return collections.OrderedDict(pairs)
@ -485,11 +494,10 @@ def _gpu_mesh_discharge_rule(
):
del out_avals
assert isinstance(mesh, GPUMesh)
if mesh.grid or mesh.cluster:
if mesh.cluster:
raise NotImplementedError
if mesh.num_threads is None:
raise NotImplementedError
threads_axis_name, num_threads = list(mesh.shape.items())[0]
def body(*args):
# Due to aliasing, args contains aliased inputs and outputs so we remove
# outputs.
@ -503,7 +511,7 @@ def _gpu_mesh_discharge_rule(
in_specs=[any_spec] * len(in_avals),
out_specs=[any_spec] * len(in_avals),
input_output_aliases={i: i for i in range(len(in_avals))},
grid=((threads_axis_name, num_threads),),
grid=tuple(mesh.shape.items()),
)(*args)
return out, ()

View File

@ -1125,7 +1125,6 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
if axis_name == grid_names[-1]:
return mgpu.warpgroup_idx(sync=False)
else:
raise NotImplementedError # The code below is untested
idx = grid_names.index(axis_name)
return arith_dialect.index_cast(
ir.IntegerType.get_signless(32),

View File

@ -1025,6 +1025,28 @@ class CoreMapTest(PallasTest):
f(), np.repeat(np.arange(2), 128).reshape(2, 128)
)
def test_multiple_wg_with_grid(self):
  """Checks axis_index/psum over a mesh mixing grid axes and warpgroups.

  Launches a 2x2 grid of blocks, each running 2 Mosaic warpgroups, and has
  every (block, warpgroup) pair write a distinct value into its own row.

  NOTE(review): the diff rendering stripped all indentation from this
  method; it has been restored here without changing any statement.
  """
  mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))

  @jax.jit
  def f():
    @pl.run_state
    def inner(y_ref):
      @pl.core_map(mesh)
      def kernel():
        # Linearized block index in both axis orders: ("x", "y") is used
        # to pick the output row, ("y", "x") to compute the value, so the
        # test distinguishes the two orderings.
        xy_idx = jax.lax.axis_index(("x", "y"))
        yx_idx = jax.lax.axis_index(("y", "x"))
        wg_idx = jax.lax.axis_index("wg")
        num_wgs = jax.lax.psum(1, "wg")
        y_ref[xy_idx, wg_idx] = jnp.broadcast_to(
            yx_idx * num_wgs + wg_idx, (128,)
        )

    y_init = jnp.zeros((4, 2, 128), np.int32)
    return inner(y_init)

  # Rows are ordered by xy_idx (x-major) while values are yx_idx-based
  # (y-major), yielding the permuted sequence below.
  np.testing.assert_array_equal(
      f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
  )
if __name__ == "__main__":
absltest.main()