[Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only

PiperOrigin-RevId: 691477248
This commit is contained in:
jax authors 2024-10-30 10:48:09 -07:00
parent 99ea4c1a4a
commit 3904ced255

View File

@ -1896,6 +1896,28 @@ class OpsTest(PallasBaseTest):
y_ref = jnp.cumsum(x, axis=axis)
np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
@parameterized.parameters(
(0, jnp.float32),
(0, jnp.bfloat16),
(1, jnp.float32),
(1, jnp.bfloat16),
(-1, jnp.float32),
(-1, jnp.bfloat16),
)
def test_triu(self, k, dtype):
if dtype == jnp.bfloat16 and jtu.test_device_matches(["tpu"]):
# TODO(mvoz): b/376330700
raise unittest.SkipTest('NYI - bf16 select')
x = jnp.arange(128 * 256, dtype=dtype).reshape((128, 256))
def kernel(x_ref, out_ref):
out_ref[...] = jnp.triu(x_ref[...], k=k)
out = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct((128, 256), dtype)
)(x)
np.testing.assert_array_equal(out, np.triu(x, k=k))
class OpsInterpretTest(OpsTest):
INTERPRET = True