1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

Disable a pallas export compatibility test that fails on TPU v6e.

PiperOrigin-RevId: 673295487
This commit is contained in:
Peter Hawkins 2024-09-11 02:00:02 -07:00 committed by jax authors
parent 808003b4e2
commit 49dd6ed8d8

@ -62,6 +62,8 @@ class CompatTest(bctu.CompatTestBase):
@jax.default_matmul_precision("bfloat16")
def test_mosaic_matmul(self):
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(apaszke): Test fails on TPU v6e")
dtype = jnp.float32
def func():
# Build the inputs here, to reduce the size of the golden inputs.