[Pallas:MGPU] Add support for grid dims in GPUMesh
Of course no communication can happen across grid dimensions (unlike over the WG dim), but we need to be able to launch multiple blocks somehow.

PiperOrigin-RevId: 688488660
commit 2db03ba54b (parent 0b3f0e11fb)
@@ -466,13 +466,22 @@ class GPUMesh:
           "Requested too many CUDA threads per block. Each Mosaic thread"
           " corresponds to 128 CUDA threads."
       )
+    if self.cluster:
+      raise NotImplementedError(
+          "Pallas/MosaicGPU does not support clusters yet."
+      )
 
   @property
   def shape(self):
     if self.num_threads is not None:
-      pairs = zip(self.axis_names, (*self.grid, self.num_threads))
+      pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
     else:
-      pairs = (*zip(self.axis_names, self.grid), (_WARPGROUP_AXIS_NAME, 1))
+      pairs = tuple(
+          zip(
+              (*self.axis_names, _WARPGROUP_AXIS_NAME),
+              (*self.grid, *self.cluster, 1),
+          )
+      )
     return collections.OrderedDict(pairs)
 
 
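For concreteness, here is a standalone sketch (plain Python, not the real dataclass) of what the new shape property composes, assuming a hypothetical mesh with grid=(2, 2), cluster=(), num_threads=2 and axis_names=("x", "y", "wg"), i.e. the configuration used by the test added below:

import collections

# Hypothetical mesh parameters; cluster stays empty because the new
# __post_init__ check above rejects clusters.
grid, cluster, num_threads = (2, 2), (), 2
axis_names = ("x", "y", "wg")

if num_threads is not None:
  # Grid axes first, then (empty) cluster axes, then the warpgroup axis.
  pairs = zip(axis_names, (*grid, *cluster, num_threads))
else:
  # Without explicit threads, an implicit warpgroup axis of size 1 is appended.
  pairs = tuple(zip((*axis_names, "wg"), (*grid, *cluster, 1)))

print(collections.OrderedDict(pairs))  # maps x -> 2, y -> 2, wg -> 2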
@@ -485,11 +494,10 @@ def _gpu_mesh_discharge_rule(
 ):
   del out_avals
   assert isinstance(mesh, GPUMesh)
-  if mesh.grid or mesh.cluster:
+  if mesh.cluster:
     raise NotImplementedError
   if mesh.num_threads is None:
     raise NotImplementedError
-  threads_axis_name, num_threads = list(mesh.shape.items())[0]
   def body(*args):
     # Due to aliasing, args contains aliased inputs and outputs so we remove
     # outputs.
@@ -503,7 +511,7 @@ def _gpu_mesh_discharge_rule(
       in_specs=[any_spec] * len(in_avals),
       out_specs=[any_spec] * len(in_avals),
      input_output_aliases={i: i for i in range(len(in_avals))},
-      grid=((threads_axis_name, num_threads),),
+      grid=tuple(mesh.shape.items()),
   )(*args)
   return out, ()
 
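The net effect of the two hunks above is that the discharge rule no longer rejects a non-empty mesh.grid and no longer extracts just the warpgroup axis; it now forwards every mesh axis as a (name, size) pair in the grid of the inner kernel call. A small sketch under the same hypothetical mesh shape as above:

import collections

# Hypothetical mesh shape, matching the sketch after the GPUMesh hunk.
shape = collections.OrderedDict([("x", 2), ("y", 2), ("wg", 2)])

# Named grid now built by the discharge rule: all axes, not only "wg".
grid = tuple(shape.items())
print(grid)  # (('x', 2), ('y', 2), ('wg', 2))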
@@ -1125,7 +1125,6 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
     if axis_name == grid_names[-1]:
       return mgpu.warpgroup_idx(sync=False)
     else:
-      raise NotImplementedError  # The code below is untested
       idx = grid_names.index(axis_name)
       return arith_dialect.index_cast(
           ir.IntegerType.get_signless(32),
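The lowering change can be summarized as: the trailing grid name is the warpgroup axis and resolves to the warpgroup index, while any other name now resolves to the CUDA block index along its grid dimension (the path that was previously guarded by the raise). A plain-Python sketch of that dispatch, using a hypothetical helper and arguments rather than the real MLIR lowering:

def axis_index_sketch(grid_names, axis_name, block_ids, warpgroup_idx):
  # block_ids: hypothetical per-dimension CUDA block indices (blockIdx.x, ...).
  # warpgroup_idx: hypothetical index of the current warpgroup within the block.
  if axis_name == grid_names[-1]:
    return warpgroup_idx  # the trailing axis is the warpgroup axis
  idx = grid_names.index(axis_name)
  return block_ids[idx]  # other axes map to the block index of that dimension

assert axis_index_sketch(("x", "y", "wg"), "y", block_ids=(3, 1), warpgroup_idx=0) == 1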
@@ -1025,6 +1025,28 @@ class CoreMapTest(PallasTest):
         f(), np.repeat(np.arange(2), 128).reshape(2, 128)
     )
 
+  def test_multiple_wg_with_grid(self):
+    mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))
+
+    @jax.jit
+    def f():
+      @pl.run_state
+      def inner(y_ref):
+        @pl.core_map(mesh)
+        def kernel():
+          xy_idx = jax.lax.axis_index(("x", "y"))
+          yx_idx = jax.lax.axis_index(("y", "x"))
+          wg_idx = jax.lax.axis_index("wg")
+          num_wgs = jax.lax.psum(1, "wg")
+          y_ref[xy_idx, wg_idx] = jnp.broadcast_to(
+              yx_idx * num_wgs + wg_idx, (128,)
+          )
+      y_init = jnp.zeros((4, 2, 128), np.int32)
+      return inner(y_init)
+    np.testing.assert_array_equal(
+        f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
+    )
+
 
 if __name__ == "__main__":
   absltest.main()
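As a sanity check on the reference array in the new test: assuming axis_index over a tuple of names linearizes them row-major in the given order, each block writes row xy_idx = 2*x + y, and the stored value is yx_idx * num_wgs + wg_idx with yx_idx = 2*y + x and num_wgs = 2. A plain-NumPy re-derivation (hypothetical helper code, not part of the commit) reproduces the expected ordering:

import numpy as np

rows = np.zeros((4, 2), dtype=np.int32)
for x in range(2):
  for y in range(2):
    for wg in range(2):  # two warpgroups per block
      xy_idx, yx_idx = 2 * x + y, 2 * y + x
      rows[xy_idx, wg] = yx_idx * 2 + wg

print(rows.ravel())  # [0 1 4 5 2 3 6 7], matching np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128)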