Merge pull request #22440 from gnecula:pallas_test_clean

PiperOrigin-RevId: 652513116
This commit is contained in:
jax authors 2024-07-15 09:55:38 -07:00
commit 7255ab146b

View File

@ -279,7 +279,6 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
],
grid=8,
),
debug=False,
)(program, x)
expected = x
@ -676,11 +675,10 @@ class PallasCallDynamicGridInterpreterTest(PallasCallDynamicGridTest):
class PallasCallDMATest(PallasBaseTest):
def setUp(self):
super().setUp()
if not jtu.is_device_tpu_at_least(4):
self.skipTest('DMAs not supported on TPU generations <= 3')
super().setUp()
def test_can_have_unspecified_memory_spaces(self):
def kernel(x_ref, y_ref):
# Just test whether things compile
@ -1483,13 +1481,12 @@ class PallasCallDMATest(PallasBaseTest):
class PallasCallRemoteDMATest(parameterized.TestCase):
def setUp(self):
super().setUp()
if jax.device_count() < 2:
self.skipTest('Only >=2 devices are supported.')
if not jtu.is_device_tpu_at_least(5):
self.skipTest('Only works with TPU v5')
super().setUp()
@parameterized.named_parameters(
('vmem', pltpu.TPUMemorySpace.VMEM),
('hbm', pltpu.TPUMemorySpace.ANY),
@ -1692,7 +1689,7 @@ class PallasCallTest(PallasBaseTest):
def kernel(x, y):
y[:] = x[:]
x = jnp.arange(1024.).reshape(8, 128)
f = pl.pallas_call(
f = self.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
compiler_params=dict(
@ -1716,12 +1713,12 @@ class PallasCallTest(PallasBaseTest):
x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
with self.assertRaises(xla_extension.XlaRuntimeError):
pl.pallas_call(
self.pallas_call(
kernel,
out_shape=x,
compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)),
)(x)
pl.pallas_call(
self.pallas_call(
kernel,
out_shape=x,
compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))),
@ -1735,7 +1732,7 @@ class PallasCallTest(PallasBaseTest):
def f(x, y):
z = jax.numpy.add(x, y)
return pl.pallas_call(
return self.pallas_call(
kernel,
grid=(3,),
in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))],
@ -1766,7 +1763,7 @@ class PallasCallTest(PallasBaseTest):
f'Requested internal scratch size {requested_bytes} needs to be at'
' least',
):
pl.pallas_call(
self.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
compiler_params=dict(
@ -1914,12 +1911,6 @@ class PallasUXTest(PallasBaseTest):
class PallasMegacoreTest(PallasBaseTest):
def setUp(self):
if jtu.device_under_test() != 'tpu':
self.skipTest('Test only works on TPU')
super().setUp()
def test_megacore_splitting(self):
# We want to make sure a 3-sized dimension is split across megacore
# correctly, and if we combine the (3, 3) dimensions together it is still
@ -1993,11 +1984,10 @@ class PallasCallVmapTest(PallasBaseTest):
class PallasCallDynamicDMATest(PallasBaseTest):
def setUp(self):
super().setUp()
if not jtu.is_device_tpu_at_least(4):
self.skipTest('DMAs not supported on TPU generations <= 3')
super().setUp()
def test_simple_tile_aligned_dynamic_size_dma(self):
def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):