From 1a544b6f363fbb03edc40e03d759cd42a6b64733 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 7 Nov 2024 08:36:44 -0800 Subject: [PATCH] [Pallas] Fix lowering tests for reduction ops Remove unnecessary skip statements. Also added tests for bf16 types. PiperOrigin-RevId: 694130207 --- tests/pallas/ops_test.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 318df0b0b..58d353677 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1793,6 +1793,7 @@ class OpsTest(PallasBaseTest): for axis in [0, 1, (1,), (0, 1)] for dtype in [ "float16", + "bfloat16", "float32", "float64", "int32", @@ -1800,28 +1801,29 @@ class OpsTest(PallasBaseTest): "uint32", "uint64", ] - if isinstance(axis, int) or "arg" not in op_name ]) def test_array_reduce(self, op, dtype, axis): - if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: - self.skipTest("16-bit types are not supported on TPU") + if not isinstance(axis, int): + self.skipTest("TODO: tuple axes are not yet supported") if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") + if jtu.test_device_matches(["tpu"]): + self.skipTest("Unimplemented primitive: broadcast_to") + + if jtu.test_device_matches(["tpu"]) and dtype == "float16": + self.skipTest("float16 is not supported on TPU") + # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects # `index_type` to be i32 if ( jax.config.x64_enabled and jtu.test_device_matches(["gpu"]) - and op in {jnp.argmin, jnp.argmax} + and op in (jnp.argmin, jnp.argmax) ): self.skipTest("Not supported on GPU in 64-bit mode") - # The Pallas TPU lowering currently supports only blocks of rank >= 1 - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") - m, n = 32, 8 def make_x(key):