From 47a76df7ccc8eb209a6a2e2b8190b4cf69527250 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Oct 2023 14:37:28 -0700 Subject: [PATCH] [IFRT] Fix incorrect type numbers for e4m3 and e5m2 types. These types didn't match between xla::PrimitiveType and ifrt::DType. Add a static_assert to enforce equality. PiperOrigin-RevId: 576288042 --- CHANGELOG.md | 3 +++ tests/BUILD | 1 + tests/array_test.py | 8 ++++++++ 3 files changed, 12 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b91c7f3a0..9a03dd56b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ Remember to align the itemized text with the first line of an item within a list ## jaxlib 0.4.20 +* Bug fixes + * Fixed some type confusion between E4M3 and E5M2 float8 types. + ## jax 0.4.19 (Oct 19, 2023) * New Features diff --git a/tests/BUILD b/tests/BUILD index 4cbdb3bcc..511536ee6 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -252,6 +252,7 @@ jax_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", + "//jax:internal_test_util", ], ) diff --git a/tests/array_test.py b/tests/array_test.py index c1cec797f..dee62266b 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -30,6 +30,7 @@ from jax._src import op_shardings from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.util import safe_zip from jax._src.sharding_impls import (_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map) @@ -772,6 +773,13 @@ class JaxArrayTest(jtu.JaxTestCase): self.assertArraysEqual(out, arr_copy * 2) self.assertTrue(arr.is_deleted()) + @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats) + @unittest.skipIf(xla_extension_version < 208, "Test requires jaxlib > 0.4.19") + def test_shards_have_correct_dtype(self, dtype): + x = jnp.ones((), dtype=dtype) + for shard in x.addressable_shards: + self.assertEqual(shard.data.dtype, dtype) + class ShardingTest(jtu.JaxTestCase):