[IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.

These types didn't match between xla::PrimitiveType and ifrt::DType.

Add a static_assert to enforce equality.

PiperOrigin-RevId: 576288042
This commit is contained in:
Peter Hawkins 2023-10-24 14:37:28 -07:00 committed by jax authors
parent 4897a5fb5a
commit 47a76df7cc
3 changed files with 12 additions and 0 deletions

View File

@ -10,6 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
## jaxlib 0.4.20
* Bug fixes
* Fixed some type confusion between E4M3 and E5M2 float8 types.
## jax 0.4.19 (Oct 19, 2023)
* New Features

View File

@ -252,6 +252,7 @@ jax_test(
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
"//jax:internal_test_util",
],
)

View File

@ -30,6 +30,7 @@ from jax._src import op_shardings
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
pmap_sharding_devices_indices_map)
@ -772,6 +773,13 @@ class JaxArrayTest(jtu.JaxTestCase):
self.assertArraysEqual(out, arr_copy * 2)
self.assertTrue(arr.is_deleted())
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
@unittest.skipIf(xla_extension_version < 208, "Test requires jaxlib > 0.4.19")
def test_shards_have_correct_dtype(self, dtype):
x = jnp.ones((), dtype=dtype)
for shard in x.addressable_shards:
self.assertEqual(shard.data.dtype, dtype)
class ShardingTest(jtu.JaxTestCase):