[PALLAS] add test for large indexing.

PiperOrigin-RevId: 611925093
This commit is contained in:
Blake Hechtman 2024-03-01 22:39:42 -08:00 committed by jax authors
parent 51a31e505f
commit ab8346956c

View File

@ -1035,6 +1035,32 @@ class PallasCallDMATest(parameterized.TestCase):
)(x)
np.testing.assert_array_equal(y, x)
def test_large_array_indexing(self):
n = 6
dtype = jnp.bfloat16
x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0)
def kernel(index, x, y, sem):
pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait()
run = pl.pallas_call(kernel,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
in_specs=[
pl.BlockSpec(
memory_space=pltpu.TPUMemorySpace.ANY)],
out_specs=pl.BlockSpec(
memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=[pltpu.SemaphoreType.DMA],
),
out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype),
)
for i in range(x.shape[0]):
y = run(jnp.array([i], dtype=jnp.int32), x)
np.testing.assert_array_equal(y, i)
del y
class PallasCallRemoteDMATest(parameterized.TestCase):