[Pallas] Add scalar f32 downcast test cases.

PiperOrigin-RevId: 678779025
This commit is contained in:
jax authors 2024-09-25 11:24:58 -07:00
parent e05c37c667
commit 70346bda74

View File

@ -606,6 +606,30 @@ class OpsTest(PallasBaseTest):
y, y_ref = y.astype(np.float32), y_ref.astype(np.float32)
np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.)
@parameterized.parameters(
jnp.bfloat16,
jnp.float8_e5m2,
jnp.float8_e4m3fn,
)
@jtu.skip_on_devices("gpu")
def test_scalar_downcast_float32(self, dtype):
def kernel(x_ref, o_ref):
o_ref[0, 0] = x_ref[:][0, 0].astype(dtype)
x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
result = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec((8, 128), lambda *_: (0, 0)),
],
out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
out_shape=jax.ShapeDtypeStruct([1, 1], dtype),
grid=(1,),
)(x)
np.testing.assert_array_equal(result[0, 0], x[0, 0].astype(dtype))
@parameterized.product(
shape=((64,), (8, 8)),
dtype=(jnp.int32, jnp.int16, jnp.int8),