mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PALLAS] add test for large indexing.
PiperOrigin-RevId: 611925093
This commit is contained in:
parent
51a31e505f
commit
ab8346956c
@ -1035,6 +1035,32 @@ class PallasCallDMATest(parameterized.TestCase):
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_large_array_indexing(self):
|
||||
n = 6
|
||||
dtype = jnp.bfloat16
|
||||
x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0)
|
||||
|
||||
def kernel(index, x, y, sem):
|
||||
pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait()
|
||||
|
||||
run = pl.pallas_call(kernel,
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=1,
|
||||
in_specs=[
|
||||
pl.BlockSpec(
|
||||
memory_space=pltpu.TPUMemorySpace.ANY)],
|
||||
out_specs=pl.BlockSpec(
|
||||
memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
scratch_shapes=[pltpu.SemaphoreType.DMA],
|
||||
),
|
||||
out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype),
|
||||
)
|
||||
|
||||
for i in range(x.shape[0]):
|
||||
y = run(jnp.array([i], dtype=jnp.int32), x)
|
||||
np.testing.assert_array_equal(y, i)
|
||||
del y
|
||||
|
||||
|
||||
class PallasCallRemoteDMATest(parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user