mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only
PiperOrigin-RevId: 691477248
This commit is contained in:
parent
99ea4c1a4a
commit
3904ced255
@ -1896,6 +1896,28 @@ class OpsTest(PallasBaseTest):
|
||||
y_ref = jnp.cumsum(x, axis=axis)
|
||||
np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
|
||||
|
||||
@parameterized.parameters(
|
||||
(0, jnp.float32),
|
||||
(0, jnp.bfloat16),
|
||||
(1, jnp.float32),
|
||||
(1, jnp.bfloat16),
|
||||
(-1, jnp.float32),
|
||||
(-1, jnp.bfloat16),
|
||||
)
|
||||
def test_triu(self, k, dtype):
|
||||
if dtype == jnp.bfloat16 and jtu.test_device_matches(["tpu"]):
|
||||
# TODO(mvoz): b/376330700
|
||||
raise unittest.SkipTest('NYI - bf16 select')
|
||||
|
||||
x = jnp.arange(128 * 256, dtype=dtype).reshape((128, 256))
|
||||
|
||||
def kernel(x_ref, out_ref):
|
||||
out_ref[...] = jnp.triu(x_ref[...], k=k)
|
||||
|
||||
out = self.pallas_call(
|
||||
kernel, out_shape=jax.ShapeDtypeStruct((128, 256), dtype)
|
||||
)(x)
|
||||
np.testing.assert_array_equal(out, np.triu(x, k=k))
|
||||
|
||||
class OpsInterpretTest(OpsTest):
|
||||
INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user