Temporarily disable broken tests in tpu_pallas_pipeline_test.py

PiperOrigin-RevId: 660972804
This commit is contained in:
Justin Fu 2024-08-08 14:03:21 -07:00 committed by jax authors
parent 12a9c8cfd4
commit deefbdd626

View File

@ -139,6 +139,8 @@ class PallasCallPipelineTest(parameterized.TestCase):
('hbm', pltpu.TPUMemorySpace.ANY),
)
def test_pipeline_matmul(self, memory_space):
# TODO(b/358121809): Re-enable this test once the bug is fixed.
self.skipTest('Broken test.')
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.uniform(k1, (512, 512))
y = jax.random.uniform(k2, (512, 512))
@ -184,6 +186,8 @@ class PallasCallPipelineTest(parameterized.TestCase):
('hbm', pltpu.TPUMemorySpace.ANY),
)
def test_double_pipeline_matmul(self, memory_space):
# TODO(b/358121809): Re-enable this test once the bug is fixed.
self.skipTest('Broken test.')
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.uniform(k1, (512, 512))
y = jax.random.uniform(k2, (512, 512))
@ -535,6 +539,8 @@ class PallasCallCollectivePipelineTest(parameterized.TestCase):
)
def test_pipeline_throughput_optimized_allgather_matmul(
self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles):
# TODO(b/358121809): Re-enable this test once the bug is fixed.
self.skipTest('Broken test.')
input_dtype = out_dtype
num_devices = jax.local_device_count()
@ -1065,6 +1071,8 @@ class PallasCallCollectivePipelineTest(parameterized.TestCase):
)
def test_pipeline_throughput_optimized_matmul_reducescatter(
self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles):
# TODO(b/358121809): Re-enable this test once the bug is fixed.
self.skipTest('Broken test.')
input_dtype = jnp.float32
num_devices = jax.device_count()
@ -1325,6 +1333,8 @@ class PallasCallMegacoreTest(parameterized.TestCase):
super().setUp()
def test_can_partition_nondivisible_grid_with_dynamic_dimensions(self):
# TODO(b/358121809): Re-enable this test once the bug is fixed.
self.skipTest('Broken test.')
def mul_pipeline(x_ref, y_ref):
y_ref[...] = x_ref[...] * 2
@ -1359,6 +1369,8 @@ class PallasCallMegacoreTest(parameterized.TestCase):
np.testing.assert_allclose(func(jnp.array([5]), x), x * 2)
def test_megacore_mul(self):
# TODO(b/358121809): Re-enable this test once the bug is fixed.
self.skipTest('Broken test.')
x = jax.random.uniform(jax.random.key(0), (512, 512))
def matmul_pipeline(x_ref, y_ref):
@ -1396,6 +1408,8 @@ class PallasCallMegacoreTest(parameterized.TestCase):
(768, 1024, 768, 256, 512, 256),
)
def test_megacore_matmul(self, m, k, n, bm, bk, bn):
# TODO(b/358121809): Re-enable this test once the bug is fixed.
self.skipTest('Broken test.')
k1, k2 = jax.random.split(jax.random.key(42))
x = jax.random.uniform(k1, (m, k))
y = jax.random.uniform(k2, (k, n))