mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Temporarily disable broken tests in tpu_pallas_pipeline_test.py
PiperOrigin-RevId: 660972804
This commit is contained in:
parent
12a9c8cfd4
commit
deefbdd626
@ -139,6 +139,8 @@ class PallasCallPipelineTest(parameterized.TestCase):
|
||||
('hbm', pltpu.TPUMemorySpace.ANY),
|
||||
)
|
||||
def test_pipeline_matmul(self, memory_space):
|
||||
# TODO(b/358121809): Re-enable this test once the bug is fixed.
|
||||
self.skipTest('Broken test.')
|
||||
k1, k2 = jax.random.split(jax.random.key(0))
|
||||
x = jax.random.uniform(k1, (512, 512))
|
||||
y = jax.random.uniform(k2, (512, 512))
|
||||
@ -184,6 +186,8 @@ class PallasCallPipelineTest(parameterized.TestCase):
|
||||
('hbm', pltpu.TPUMemorySpace.ANY),
|
||||
)
|
||||
def test_double_pipeline_matmul(self, memory_space):
|
||||
# TODO(b/358121809): Re-enable this test once the bug is fixed.
|
||||
self.skipTest('Broken test.')
|
||||
k1, k2 = jax.random.split(jax.random.key(0))
|
||||
x = jax.random.uniform(k1, (512, 512))
|
||||
y = jax.random.uniform(k2, (512, 512))
|
||||
@ -535,6 +539,8 @@ class PallasCallCollectivePipelineTest(parameterized.TestCase):
|
||||
)
|
||||
def test_pipeline_throughput_optimized_allgather_matmul(
|
||||
self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles):
|
||||
# TODO(b/358121809): Re-enable this test once the bug is fixed.
|
||||
self.skipTest('Broken test.')
|
||||
input_dtype = out_dtype
|
||||
num_devices = jax.local_device_count()
|
||||
|
||||
@ -1065,6 +1071,8 @@ class PallasCallCollectivePipelineTest(parameterized.TestCase):
|
||||
)
|
||||
def test_pipeline_throughput_optimized_matmul_reducescatter(
|
||||
self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles):
|
||||
# TODO(b/358121809): Re-enable this test once the bug is fixed.
|
||||
self.skipTest('Broken test.')
|
||||
input_dtype = jnp.float32
|
||||
num_devices = jax.device_count()
|
||||
|
||||
@ -1325,6 +1333,8 @@ class PallasCallMegacoreTest(parameterized.TestCase):
|
||||
super().setUp()
|
||||
|
||||
def test_can_partition_nondivisible_grid_with_dynamic_dimensions(self):
|
||||
# TODO(b/358121809): Re-enable this test once the bug is fixed.
|
||||
self.skipTest('Broken test.')
|
||||
|
||||
def mul_pipeline(x_ref, y_ref):
|
||||
y_ref[...] = x_ref[...] * 2
|
||||
@ -1359,6 +1369,8 @@ class PallasCallMegacoreTest(parameterized.TestCase):
|
||||
np.testing.assert_allclose(func(jnp.array([5]), x), x * 2)
|
||||
|
||||
def test_megacore_mul(self):
|
||||
# TODO(b/358121809): Re-enable this test once the bug is fixed.
|
||||
self.skipTest('Broken test.')
|
||||
x = jax.random.uniform(jax.random.key(0), (512, 512))
|
||||
|
||||
def matmul_pipeline(x_ref, y_ref):
|
||||
@ -1396,6 +1408,8 @@ class PallasCallMegacoreTest(parameterized.TestCase):
|
||||
(768, 1024, 768, 256, 512, 256),
|
||||
)
|
||||
def test_megacore_matmul(self, m, k, n, bm, bk, bn):
|
||||
# TODO(b/358121809): Re-enable this test once the bug is fixed.
|
||||
self.skipTest('Broken test.')
|
||||
k1, k2 = jax.random.split(jax.random.key(42))
|
||||
x = jax.random.uniform(k1, (m, k))
|
||||
y = jax.random.uniform(k2, (k, n))
|
||||
|
Loading…
x
Reference in New Issue
Block a user