mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Pallas] Fix lowering tests for reduction ops
Remove unnecessary skip statements. Also added tests for bf16 types. PiperOrigin-RevId: 694130207
This commit is contained in:
parent
de06584d98
commit
1a544b6f36
@ -1793,6 +1793,7 @@ class OpsTest(PallasBaseTest):
|
||||
for axis in [0, 1, (1,), (0, 1)]
|
||||
for dtype in [
|
||||
"float16",
|
||||
"bfloat16",
|
||||
"float32",
|
||||
"float64",
|
||||
"int32",
|
||||
@ -1800,28 +1801,29 @@ class OpsTest(PallasBaseTest):
|
||||
"uint32",
|
||||
"uint64",
|
||||
]
|
||||
if isinstance(axis, int) or "arg" not in op_name
|
||||
])
|
||||
def test_array_reduce(self, op, dtype, axis):
|
||||
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
|
||||
self.skipTest("16-bit types are not supported on TPU")
|
||||
if not isinstance(axis, int):
|
||||
self.skipTest("TODO: tuple axes are not yet supported")
|
||||
|
||||
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
|
||||
self.skipTest("64-bit types require x64_enabled")
|
||||
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Unimplemented primitive: broadcast_to")
|
||||
|
||||
if jtu.test_device_matches(["tpu"]) and dtype == "float16":
|
||||
self.skipTest("float16 is not supported on TPU")
|
||||
|
||||
# Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
|
||||
# `index_type` to be i32
|
||||
if (
|
||||
jax.config.x64_enabled
|
||||
and jtu.test_device_matches(["gpu"])
|
||||
and op in {jnp.argmin, jnp.argmax}
|
||||
and op in (jnp.argmin, jnp.argmax)
|
||||
):
|
||||
self.skipTest("Not supported on GPU in 64-bit mode")
|
||||
|
||||
# The Pallas TPU lowering currently supports only blocks of rank >= 1
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Not supported on TPU")
|
||||
|
||||
m, n = 32, 8
|
||||
|
||||
def make_x(key):
|
||||
|
Loading…
x
Reference in New Issue
Block a user