mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
These types didn't match between xla::PrimitiveType and ifrt::DType. Add a static_assert to enforce equality. PiperOrigin-RevId: 576288042
This commit is contained in:
parent
4897a5fb5a
commit
47a76df7cc
@ -10,6 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jaxlib 0.4.20
|
||||
|
||||
* Bug fixes
|
||||
* Fixed some type confusion between E4M3 and E5M2 float8 types.
|
||||
|
||||
## jax 0.4.19 (Oct 19, 2023)
|
||||
|
||||
* New Features
|
||||
|
@ -252,6 +252,7 @@ jax_test(
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
"//jax:internal_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ from jax._src import op_shardings
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
|
||||
pmap_sharding_devices_indices_map)
|
||||
@ -772,6 +773,13 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, arr_copy * 2)
|
||||
self.assertTrue(arr.is_deleted())
|
||||
|
||||
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
|
||||
@unittest.skipIf(xla_extension_version < 208, "Test requires jaxlib > 0.4.19")
|
||||
def test_shards_have_correct_dtype(self, dtype):
|
||||
x = jnp.ones((), dtype=dtype)
|
||||
for shard in x.addressable_shards:
|
||||
self.assertEqual(shard.data.dtype, dtype)
|
||||
|
||||
|
||||
class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user