Merge pull request #22440 from gnecula:pallas_test_clean
PiperOrigin-RevId: 652513116
Commit 7255ab146b
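The hunks below make two repeated edits: direct pl.pallas_call(...) invocations inside tests become self.pallas_call(...), and super().setUp() is moved ahead of the skipTest checks in several test classes. The body of the PallasBaseTest.pallas_call helper that the first edit relies on is not shown in this diff; the sketch below is only an assumption of what such a wrapper typically looks like (the INTERPRET flag and its plumbing are illustrative, not taken from the change).

from jax._src import test_util as jtu
from jax.experimental import pallas as pl


class PallasBaseTest(jtu.JaxTestCase):
  INTERPRET = False  # interpreter-mode subclasses would override this

  @classmethod
  def pallas_call(cls, *args, **kwargs):
    # Funnel every test's pallas_call through one helper so a subclass can
    # switch the whole suite to interpreter mode by flipping INTERPRET.
    return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs)

Routing the calls through a base-class helper is what lets interpreter variants such as PallasCallDynamicGridInterpreterTest (which the hunk headers show subclassing PallasCallDynamicGridTest) reuse the same test bodies under a different execution mode.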
@@ -279,7 +279,6 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
             ],
             grid=8,
         ),
-        debug=False,
     )(program, x)
 
     expected = x
@@ -676,11 +675,10 @@ class PallasCallDynamicGridInterpreterTest(PallasCallDynamicGridTest):
 class PallasCallDMATest(PallasBaseTest):
 
   def setUp(self):
+    super().setUp()
     if not jtu.is_device_tpu_at_least(4):
       self.skipTest('DMAs not supported on TPU generations <= 3')
 
-    super().setUp()
-
   def test_can_have_unspecified_memory_spaces(self):
     def kernel(x_ref, y_ref):
       # Just test whether things compile
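The reordering in the hunk above is the same one applied to PallasCallRemoteDMATest and PallasCallDynamicDMATest further down: super().setUp() now runs before any skipTest call. Since unittest's skipTest raises SkipTest, nothing after it in setUp executes for a skipped configuration, so the old ordering skipped the base-class setup as well; with the new ordering it always runs first. As a minimal sketch, the post-change method assembled from the hunk reads:

class PallasCallDMATest(PallasBaseTest):

  def setUp(self):
    super().setUp()  # base-class setup runs before any skip decision
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')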
@@ -1483,13 +1481,12 @@ class PallasCallDMATest(PallasBaseTest):
 class PallasCallRemoteDMATest(parameterized.TestCase):
 
   def setUp(self):
+    super().setUp()
     if jax.device_count() < 2:
       self.skipTest('Only >=2 devices are supported.')
     if not jtu.is_device_tpu_at_least(5):
       self.skipTest('Only works with TPU v5')
 
-    super().setUp()
-
   @parameterized.named_parameters(
       ('vmem', pltpu.TPUMemorySpace.VMEM),
       ('hbm', pltpu.TPUMemorySpace.ANY),
@@ -1692,7 +1689,7 @@ class PallasCallTest(PallasBaseTest):
     def kernel(x, y):
       y[:] = x[:]
     x = jnp.arange(1024.).reshape(8, 128)
-    f = pl.pallas_call(
+    f = self.pallas_call(
         kernel,
         out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
         compiler_params=dict(
@@ -1716,12 +1713,12 @@ class PallasCallTest(PallasBaseTest):
 
     x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
     with self.assertRaises(xla_extension.XlaRuntimeError):
-      pl.pallas_call(
+      self.pallas_call(
           kernel,
           out_shape=x,
           compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)),
       )(x)
-    pl.pallas_call(
+    self.pallas_call(
         kernel,
         out_shape=x,
         compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))),
@@ -1735,7 +1732,7 @@ class PallasCallTest(PallasBaseTest):
 
     def f(x, y):
       z = jax.numpy.add(x, y)
-      return pl.pallas_call(
+      return self.pallas_call(
           kernel,
           grid=(3,),
           in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))],
@@ -1766,7 +1763,7 @@ class PallasCallTest(PallasBaseTest):
         f'Requested internal scratch size {requested_bytes} needs to be at'
         ' least',
     ):
-      pl.pallas_call(
+      self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
          compiler_params=dict(
@@ -1914,12 +1911,6 @@ class PallasUXTest(PallasBaseTest):
 
 class PallasMegacoreTest(PallasBaseTest):
 
-  def setUp(self):
-    if jtu.device_under_test() != 'tpu':
-      self.skipTest('Test only works on TPU')
-
-    super().setUp()
-
   def test_megacore_splitting(self):
     # We want to make sure a 3-sized dimension is split across megacore
     # correctly, and if we combine the (3, 3) dimensions together it is still
@@ -1993,11 +1984,10 @@ class PallasCallVmapTest(PallasBaseTest):
 class PallasCallDynamicDMATest(PallasBaseTest):
 
   def setUp(self):
+    super().setUp()
     if not jtu.is_device_tpu_at_least(4):
       self.skipTest('DMAs not supported on TPU generations <= 3')
 
-    super().setUp()
-
   def test_simple_tile_aligned_dynamic_size_dma(self):
 
     def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):