diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index cce40788d..80a4d8ef5 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import version as jaxlib_version import numpy as np @@ -42,6 +43,27 @@ except ImportError: dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__) +# These dtypes are not supported by neither NumPy nor TensorFlow, therefore +# we list them separately from ``jax.dlpack.SUPPORTED_DTYPES``. +extra_dlpack_dtypes = [] +if jaxlib_version >= (0, 5, 3): + extra_dlpack_dtypes = [ + jnp.float8_e4m3b11fnuz, + jnp.float8_e4m3fn, + jnp.float8_e4m3fnuz, + jnp.float8_e5m2, + jnp.float8_e5m2fnuz, + ] + [ + dtype + for name in [ + "float4_e2m1fn", + "float8_e3m4", + "float8_e4m3", + "float8_e8m0fnu", + ] + if (dtype := getattr(jnp, name, None)) + ] + numpy_dtypes = sorted( [dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16], key=lambda x: x.__name__) @@ -63,14 +85,16 @@ class DLPackTest(jtu.JaxTestCase): self.skipTest(f"DLPack not supported on {jtu.device_under_test()}") @jtu.sample_product( - shape=all_shapes, - dtype=dlpack_dtypes, - copy=[False, True, None], - use_stream=[False, True], + shape=all_shapes, + dtype=dlpack_dtypes + extra_dlpack_dtypes, + copy=[False, True, None], + use_stream=[False, True], ) @jtu.run_on_devices("gpu") - @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", - category=DeprecationWarning) + @jtu.ignore_warning( + message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning, + ) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype)