[pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test

Turns out we can mix parallel grid with `plgpu.emit_pipeline` without doing
indexing at all!

PiperOrigin-RevId: 698442820
This commit is contained in:
Sergei Lebedev 2024-11-20 10:41:24 -08:00 committed by jax authors
parent 2c9b917b9d
commit 9584ee3bb9

View File

@@ -1218,33 +1218,33 @@ class PipelineTest(PallasTest):
np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16])
def test_emit_with_parallel_grid(self):
self.skipTest("Enable once we support multiple levels of indexing")
num_steps = 4
num_steps1 = 4
num_steps2 = 5
def kernel(x_gmem, o_gmem):
gmem_slice = pl.ds(pl.program_id(0) * 32, 32)
pid = pl.program_id(0)
plgpu.emit_pipeline(
kernel_body,
in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
grid=(num_steps,),
in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
grid=(num_steps2,),
max_concurrent_steps=2,
)(x_gmem.at[gmem_slice], o_gmem.at[gmem_slice])
)(x_gmem, o_gmem)
def kernel_body(x_smem, o_smem):
o_smem[...] = x_smem[...] + 1.0
x = jnp.arange(4 * 32 * num_steps * 16)
x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
x = jnp.arange(num_steps1 * 32 * num_steps2 * 16)
x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32)
kernel_fn = pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
grid=(4, 1),
grid=(num_steps1,),
)
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
y = x + 1.0
np.testing.assert_array_equal(kernel_fn(x), y)
def test_emit_with_2d_grid(self):
num_steps1 = 4