[Pallas] Fix lowering tests for reduction ops

Remove unnecessary skip statements. Also added tests for bf16 types.

PiperOrigin-RevId: 694130207
This commit is contained in:
Ayaka 2024-11-07 08:36:44 -08:00 committed by jax authors
parent de06584d98
commit 1a544b6f36

View File

@ -1793,6 +1793,7 @@ class OpsTest(PallasBaseTest):
for axis in [0, 1, (1,), (0, 1)]
for dtype in [
"float16",
"bfloat16",
"float32",
"float64",
"int32",
@ -1800,28 +1801,29 @@ class OpsTest(PallasBaseTest):
"uint32",
"uint64",
]
if isinstance(axis, int) or "arg" not in op_name
])
def test_array_reduce(self, op, dtype, axis):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")
if not isinstance(axis, int):
self.skipTest("TODO: tuple axes are not yet supported")
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")
if jtu.test_device_matches(["tpu"]):
self.skipTest("Unimplemented primitive: broadcast_to")
if jtu.test_device_matches(["tpu"]) and dtype == "float16":
self.skipTest("float16 is not supported on TPU")
# Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
# `index_type` to be i32
if (
jax.config.x64_enabled
and jtu.test_device_matches(["gpu"])
and op in {jnp.argmin, jnp.argmax}
and op in (jnp.argmin, jnp.argmax)
):
self.skipTest("Not supported on GPU in 64-bit mode")
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")
m, n = 32, 8
def make_x(key):