mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Pallas] Add scalar f32 downcast test cases.
PiperOrigin-RevId: 678779025
This commit is contained in:
parent
e05c37c667
commit
70346bda74
@ -606,6 +606,30 @@ class OpsTest(PallasBaseTest):
|
||||
y, y_ref = y.astype(np.float32), y_ref.astype(np.float32)
|
||||
np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.)
|
||||
|
||||
@parameterized.parameters(
|
||||
jnp.bfloat16,
|
||||
jnp.float8_e5m2,
|
||||
jnp.float8_e4m3fn,
|
||||
)
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def test_scalar_downcast_float32(self, dtype):
|
||||
|
||||
def kernel(x_ref, o_ref):
|
||||
o_ref[0, 0] = x_ref[:][0, 0].astype(dtype)
|
||||
|
||||
x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
|
||||
result = self.pallas_call(
|
||||
kernel,
|
||||
in_specs=[
|
||||
pl.BlockSpec((8, 128), lambda *_: (0, 0)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
|
||||
out_shape=jax.ShapeDtypeStruct([1, 1], dtype),
|
||||
grid=(1,),
|
||||
)(x)
|
||||
|
||||
np.testing.assert_array_equal(result[0, 0], x[0, 0].astype(dtype))
|
||||
|
||||
@parameterized.product(
|
||||
shape=((64,), (8, 8)),
|
||||
dtype=(jnp.int32, jnp.int16, jnp.int8),
|
||||
|
Loading…
x
Reference in New Issue
Block a user