mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[dlpack] Support more DLPack dtypes now that we target DLPack 1.1
I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor TensorFlow currently support importing DLPack arrays with the new dtypes. PiperOrigin-RevId: 736882904
This commit is contained in:
parent
c9ac82c826
commit
97bbc37e83
@ -22,6 +22,7 @@ import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -42,6 +43,27 @@ except ImportError:
|
||||
|
||||
dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__)
|
||||
|
||||
# These dtypes are not supported by neither NumPy nor TensorFlow, therefore
|
||||
# we list them separately from ``jax.dlpack.SUPPORTED_DTYPES``.
|
||||
extra_dlpack_dtypes = []
|
||||
if jaxlib_version >= (0, 5, 3):
|
||||
extra_dlpack_dtypes = [
|
||||
jnp.float8_e4m3b11fnuz,
|
||||
jnp.float8_e4m3fn,
|
||||
jnp.float8_e4m3fnuz,
|
||||
jnp.float8_e5m2,
|
||||
jnp.float8_e5m2fnuz,
|
||||
] + [
|
||||
dtype
|
||||
for name in [
|
||||
"float4_e2m1fn",
|
||||
"float8_e3m4",
|
||||
"float8_e4m3",
|
||||
"float8_e8m0fnu",
|
||||
]
|
||||
if (dtype := getattr(jnp, name, None))
|
||||
]
|
||||
|
||||
numpy_dtypes = sorted(
|
||||
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
|
||||
key=lambda x: x.__name__)
|
||||
@ -63,14 +85,16 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=dlpack_dtypes,
|
||||
copy=[False, True, None],
|
||||
use_stream=[False, True],
|
||||
shape=all_shapes,
|
||||
dtype=dlpack_dtypes + extra_dlpack_dtypes,
|
||||
copy=[False, True, None],
|
||||
use_stream=[False, True],
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning)
|
||||
@jtu.ignore_warning(
|
||||
message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user