[dlpack] Support more DLPack dtypes now that we target DLPack 1.1

I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor
TensorFlow currently support importing DLPack arrays with the new dtypes.

PiperOrigin-RevId: 736882904
This commit is contained in:
Sergei Lebedev 2025-03-14 09:09:58 -07:00 committed by jax authors
parent c9ac82c826
commit 97bbc37e83

View File

@ -22,6 +22,7 @@ import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import version as jaxlib_version
import numpy as np
@ -42,6 +43,27 @@ except ImportError:
dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__)
# These dtypes are not supported by neither NumPy nor TensorFlow, therefore
# we list them separately from ``jax.dlpack.SUPPORTED_DTYPES``.
extra_dlpack_dtypes = []
if jaxlib_version >= (0, 5, 3):
extra_dlpack_dtypes = [
jnp.float8_e4m3b11fnuz,
jnp.float8_e4m3fn,
jnp.float8_e4m3fnuz,
jnp.float8_e5m2,
jnp.float8_e5m2fnuz,
] + [
dtype
for name in [
"float4_e2m1fn",
"float8_e3m4",
"float8_e4m3",
"float8_e8m0fnu",
]
if (dtype := getattr(jnp, name, None))
]
numpy_dtypes = sorted(
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
key=lambda x: x.__name__)
@ -63,14 +85,16 @@ class DLPackTest(jtu.JaxTestCase):
self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
copy=[False, True, None],
use_stream=[False, True],
shape=all_shapes,
dtype=dlpack_dtypes + extra_dlpack_dtypes,
copy=[False, True, None],
use_stream=[False, True],
)
@jtu.run_on_devices("gpu")
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
@jtu.ignore_warning(
message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning,
)
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)