mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
[pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test
Turns out we can mix parallel grid with `plgpu.emit_pipeline` without doing indexing at all! PiperOrigin-RevId: 698442820
This commit is contained in:
parent
2c9b917b9d
commit
9584ee3bb9
@ -1218,33 +1218,33 @@ class PipelineTest(PallasTest):
|
||||
np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16])
|
||||
|
||||
def test_emit_with_parallel_grid(self):
|
||||
self.skipTest("Enable once we support multiple levels of indexing")
|
||||
|
||||
num_steps = 4
|
||||
num_steps1 = 4
|
||||
num_steps2 = 5
|
||||
|
||||
def kernel(x_gmem, o_gmem):
|
||||
gmem_slice = pl.ds(pl.program_id(0) * 32, 32)
|
||||
pid = pl.program_id(0)
|
||||
plgpu.emit_pipeline(
|
||||
kernel_body,
|
||||
in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
|
||||
out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
|
||||
grid=(num_steps,),
|
||||
in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
|
||||
out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
|
||||
grid=(num_steps2,),
|
||||
max_concurrent_steps=2,
|
||||
)(x_gmem.at[gmem_slice], o_gmem.at[gmem_slice])
|
||||
)(x_gmem, o_gmem)
|
||||
|
||||
def kernel_body(x_smem, o_smem):
|
||||
o_smem[...] = x_smem[...] + 1.0
|
||||
|
||||
x = jnp.arange(4 * 32 * num_steps * 16)
|
||||
x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
|
||||
x = jnp.arange(num_steps1 * 32 * num_steps2 * 16)
|
||||
x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32)
|
||||
kernel_fn = pl.pallas_call(
|
||||
kernel,
|
||||
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
|
||||
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
|
||||
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
|
||||
grid=(4, 1),
|
||||
grid=(num_steps1,),
|
||||
)
|
||||
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
|
||||
y = x + 1.0
|
||||
np.testing.assert_array_equal(kernel_fn(x), y)
|
||||
|
||||
def test_emit_with_2d_grid(self):
|
||||
num_steps1 = 4
|
||||
|
Loading…
x
Reference in New Issue
Block a user