# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
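
"""Tests for DLPack and CUDA Array Interface interoperability.

Round-trips JAX arrays through the DLPack protocol (``jax.dlpack.to_dlpack`` /
``jax.dlpack.from_dlpack``) and the CUDA Array Interface, exchanging data with
NumPy, TensorFlow, and CuPy.
"""
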
import unittest

from absl.testing import absltest

import jax
import jax.dlpack
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import version as jaxlib_version

import numpy as np

config.parse_flags_with_absl()

try:
  import cupy
except ImportError:
  cupy = None

try:
  import tensorflow as tf
  tf_version = tuple(
      int(x) for x in tf.version.VERSION.split("-")[0].split("."))
except ImportError:
  tf = None


dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__)

# These dtypes are supported by neither NumPy nor TensorFlow, so we list them
# separately from ``jax.dlpack.SUPPORTED_DTYPES``.
extra_dlpack_dtypes = []
if jaxlib_version >= (0, 5, 3):
  extra_dlpack_dtypes = [
      jnp.float8_e4m3b11fnuz,
      jnp.float8_e4m3fn,
      jnp.float8_e4m3fnuz,
      jnp.float8_e5m2,
      jnp.float8_e5m2fnuz,
  ] + [
      dtype
      for name in [
          "float4_e2m1fn",
          "float8_e3m4",
          "float8_e4m3",
          "float8_e8m0fnu",
      ]
      if (dtype := getattr(jnp, name, None))
  ]

numpy_dtypes = sorted(
    [dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
    key=lambda x: x.__name__)

cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16]

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
empty_array_shapes = []
empty_array_shapes += [(0,), (0, 4), (3, 0),]
nonempty_nonscalar_array_shapes += [(3, 1), (1, 4), (2, 1, 4)]

nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
all_shapes = nonempty_array_shapes + empty_array_shapes
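# The shape lists above include scalars, unit dimensions, and zero-size arrays
# alongside ordinary nonempty shapes.

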
class DLPackTest(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu"]):
      self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")

  @jtu.sample_product(
      shape=all_shapes,
      dtype=dlpack_dtypes + extra_dlpack_dtypes,
      copy=[False, True, None],
      use_stream=[False, True],
  )
  @jtu.run_on_devices("gpu")
  @jtu.ignore_warning(
      message="Calling from_dlpack with a DLPack tensor",
      category=DeprecationWarning,
  )
  def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)

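    # (The local name ``np`` above is the ndarray returned by ``rng``; it
    # shadows the ``numpy`` module for the remainder of this test.)

    # _check_copy compares unsafe_buffer_pointer() values to decide whether the
    # DLPack round trip produced a copy or aliased the original buffer.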
    def _check_copy(x: jax.Array, y: jax.Array, expect_copy):
      copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()
      assert copied == expect_copy, f"Expected {'a' if expect_copy else 'no'} copy"

    # Check if the source device is preserved
    x = jax.device_put(np, jax.devices("cpu")[0])
    device = jax.devices("gpu")[0]
    y = jax.device_put(x, device)
    dl_device = y.__dlpack_device__()
    if use_stream:
      stream = tuple(y.devices())[0].get_stream_for_external_ready_events()
      dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream)
    else:
      dlpack = jax.dlpack.to_dlpack(y, copy=copy)
    z = jax.dlpack.from_dlpack(dlpack)

    self.assertEqual(z.devices(), {device})
    self.assertAllClose(np.astype(x.dtype), z)
    self.assertRaisesRegex(RuntimeError,
                           "DLPack tensor may be consumed at most once",
                           lambda: jax.dlpack.from_dlpack(dlpack))

    if shape in nonempty_array_shapes:
      _check_copy(y, z, bool(copy))

    # Check if the destination device can be specified
    make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy)
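    # ``copy`` can be False, True, or None; the explicit ``copy == False`` check
    # (rather than ``not copy``) lets None fall through to the default behavior.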
    if copy == False:
      self.assertRaisesRegex(ValueError, "copy=False", make_dlpack)
      return

    z = jax.dlpack.from_dlpack(make_dlpack())
    self.assertEqual(z.devices(), {device})
    self.assertAllClose(x, z)

    if shape in nonempty_array_shapes:
      _check_copy(x, z, True)

  @jtu.sample_product(
      shape=all_shapes,
      dtype=dlpack_dtypes,
      gpu=[False, True],
  )
  def testJaxArrayRoundTrip(self, shape, dtype, gpu):
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    if gpu and jax.default_backend() == "cpu":
      raise unittest.SkipTest("Skipping GPU test case on CPU")
    device = jax.devices("gpu" if gpu else "cpu")[0]
    x = jax.device_put(np, device)
    y = jax.dlpack.from_dlpack(x)
    self.assertEqual(y.devices(), {device})
    self.assertAllClose(np.astype(x.dtype), y)
    # Test that we can create multiple arrays from the same source.
    z = jax.dlpack.from_dlpack(x)
    self.assertEqual(z.devices(), {device})
    self.assertAllClose(np.astype(x.dtype), z)

  @jtu.sample_product(
      shape=all_shapes,
      dtype=dlpack_dtypes,
  )
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
                      category=DeprecationWarning)
  def testTensorFlowToJax(self, shape, dtype):
    if (not config.enable_x64.value and
        dtype in [jnp.int64, jnp.uint64, jnp.float64]):
      raise self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.test_device_matches(["gpu"]) and
        not tf.config.list_physical_devices("GPU")):
      raise self.skipTest("TensorFlow not configured with GPU support")

    if jtu.test_device_matches(["gpu"]) and dtype == jnp.int32:
      raise self.skipTest("TensorFlow does not place int32 tensors on GPU")

    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    with tf.device("/GPU:0" if jtu.test_device_matches(["gpu"]) else "/CPU:0"):
      x = tf.identity(tf.constant(np))
      dlpack = tf.experimental.dlpack.to_dlpack(x)
      y = jax.dlpack.from_dlpack(dlpack)
      self.assertAllClose(np, y)

  @jtu.sample_product(
      shape=all_shapes,
      dtype=dlpack_dtypes,
  )
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testJaxToTensorFlow(self, shape, dtype):
    if (not config.enable_x64.value and
        dtype in [jnp.int64, jnp.uint64, jnp.float64]):
      self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.test_device_matches(["gpu"]) and
        not tf.config.list_physical_devices("GPU")):
      raise self.skipTest("TensorFlow not configured with GPU support")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = jnp.array(np)
    # TODO(b/171320191): this line works around a missing context initialization
    # bug in TensorFlow.
    _ = tf.add(1, 1)
    dlpack = jax.dlpack.to_dlpack(x)
    y = tf.experimental.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y.numpy())

  @unittest.skipIf(not tf, "Test requires TensorFlow")
  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
                      category=DeprecationWarning)
  def testTensorFlowToJaxInt64(self):
    # See https://github.com/jax-ml/jax/issues/11895
    x = jax.dlpack.from_dlpack(
        tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64)))
    dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
    self.assertEqual(x.dtype, dtype_expected)

  @jtu.sample_product(
      shape=all_shapes,
      dtype=numpy_dtypes,
      copy=[False, True],
  )
  def testNumpyToJax(self, shape, dtype, copy):
    rng = jtu.rand_default(self.rng())
    x_np = rng(shape, dtype)
    device = jax.devices()[0]
    _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
    if jax.default_backend() == 'gpu' and not copy:
      self.assertRaisesRegex(
          ValueError,
          r"Specified .* which requires a copy",
          _from_dlpack
      )
    else:
      self.assertAllClose(x_np, _from_dlpack())

  @jtu.sample_product(
      shape=all_shapes,
      dtype=numpy_dtypes,
  )
  @jtu.run_on_devices("cpu")  # NumPy only accepts cpu DLPacks
  def testJaxToNumpy(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x_jax = jnp.array(rng(shape, dtype))
    x_np = np.from_dlpack(x_jax)
    self.assertAllClose(x_np, x_jax)

@jtu.ignore_warning(message="Calling from_dlpack.*",
|
|
category=DeprecationWarning)
|
|
def testNondefaultLayout(self):
|
|
# Generate numpy array with nonstandard layout
|
|
a = np.arange(4).reshape(2, 2)
|
|
b = a.T
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"from_dlpack got array with non-default layout with minor-to-major "
|
|
r"dimensions \(0,1\), expected \(1,0\)"):
|
|
b_jax = jax.dlpack.from_dlpack(b.__dlpack__())
|
|
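    # Note: a row-major copy imports cleanly, e.g. (illustrative only):
    #   jax.dlpack.from_dlpack(np.ascontiguousarray(b))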


class CudaArrayInterfaceTest(jtu.JaxTestCase):

  @jtu.skip_on_devices("cuda")
  def testCudaArrayInterfaceOnNonCudaFails(self):
    x = jnp.arange(5)
    self.assertFalse(hasattr(x, "__cuda_array_interface__"))
    with self.assertRaisesRegex(
        AttributeError,
        "__cuda_array_interface__ is only defined for NVidia GPU buffers.",
    ):
      _ = x.__cuda_array_interface__

  @jtu.run_on_devices("cuda")
  def testCudaArrayInterfaceOnShardedArrayFails(self):
    devices = jax.local_devices()
    if len(devices) <= 1:
      raise unittest.SkipTest("Test requires 2 or more devices")
    mesh = jax.sharding.Mesh(np.array(devices), ("x",))
    sharding = jax.sharding.NamedSharding(mesh, P("x"))
    x = jnp.arange(16)
    x = jax.device_put(x, sharding)
    self.assertFalse(hasattr(x, "__cuda_array_interface__"))
    with self.assertRaisesRegex(
        AttributeError,
        "__cuda_array_interface__ is only supported for unsharded arrays.",
    ):
      _ = x.__cuda_array_interface__

  @jtu.sample_product(
      shape=all_shapes,
      dtype=cuda_array_interface_dtypes,
  )
  @jtu.run_on_devices("cuda")
  def testCudaArrayInterfaceWorks(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = jnp.array(x)
    z = np.asarray(y)
    a = y.__cuda_array_interface__
    self.assertEqual(shape, a["shape"])
    self.assertEqual(z.__array_interface__["typestr"], a["typestr"])

@jtu.run_on_devices("cuda")
|
|
def testCudaArrayInterfaceBfloat16Fails(self):
|
|
rng = jtu.rand_default(self.rng())
|
|
x = rng((2, 2), jnp.bfloat16)
|
|
y = jnp.array(x)
|
|
with self.assertRaisesRegex(AttributeError, ".*not supported for BF16.*"):
|
|
_ = y.__cuda_array_interface__
|
|
|
|
  @jtu.sample_product(
      shape=all_shapes,
      dtype=cuda_array_interface_dtypes,
  )
  @unittest.skipIf(not cupy, "Test requires CuPy")
  @jtu.run_on_devices("cuda")
  def testJaxToCuPy(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = jnp.array(x)
    z = cupy.asarray(y)
    self.assertEqual(y.__cuda_array_interface__["data"][0],
                     z.__cuda_array_interface__["data"][0])
    self.assertAllClose(x, cupy.asnumpy(z))

  @jtu.sample_product(
      shape=all_shapes,
      dtype=jtu.dtypes.supported(cuda_array_interface_dtypes),
  )
  @unittest.skipIf(not cupy, "Test requires CuPy")
  @jtu.run_on_devices("cuda")
  def testCuPyToJax(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = cupy.asarray(x)
    z = jnp.array(y, copy=False)  # this conversion uses the DLPack protocol
    self.assertEqual(z.dtype, dtype)
    self.assertEqual(y.__cuda_array_interface__["data"][0],
                     z.__cuda_array_interface__["data"][0])
    self.assertAllClose(np.asarray(z), cupy.asnumpy(y))

  @jtu.sample_product(
      shape=all_shapes,
      dtype=jtu.dtypes.supported(cuda_array_interface_dtypes),
  )
  @jtu.run_on_devices("cuda")
  def testCaiToJax(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)

    # Use the device with the highest device_id to check that the device id is
    # correctly detected from a pointer value.
    device = jax.devices('cuda')[-1]
    with jax.default_device(device):
      y = jnp.array(x, dtype=dtype)
    self.assertEqual(y.dtype, dtype)

    # Use the JAX array's CAI support to construct objects that implement the
    # CUDA Array Interface, versions 2 and 3.
    cai = y.__cuda_array_interface__
    stream = tuple(y.devices())[0].get_stream_for_external_ready_events()

    class CAIWithoutStridesV2:
      __cuda_array_interface__ = cai.copy()
      __cuda_array_interface__["version"] = 2
      # CAI version 2 may not define strides and does not define stream.
      __cuda_array_interface__.pop("strides", None)
      __cuda_array_interface__.pop("stream", None)

    class CAIWithoutStrides:
      __cuda_array_interface__ = cai.copy()
      __cuda_array_interface__["version"] = 3
      __cuda_array_interface__["strides"] = None
      __cuda_array_interface__["stream"] = None  # default stream

    class CAIWithStrides:
      __cuda_array_interface__ = cai.copy()
      __cuda_array_interface__["version"] = 3
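      # Build C-contiguous strides in bytes: the innermost dimension gets the
      # itemsize, and each outer stride is the next inner stride times that
      # dimension's extent.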
      strides = (np.dtype(dtype).itemsize,) if shape else ()
      for s in reversed(shape[1:]):
        strides = (strides[0] * s, *strides)
      __cuda_array_interface__["strides"] = strides
      __cuda_array_interface__["stream"] = stream

    for CAIObject in [CAIWithoutStridesV2, CAIWithoutStrides,
                      CAIWithStrides]:
      z = jnp.array(CAIObject(), copy=False)
      self.assertEqual(y.__cuda_array_interface__["data"][0],
                       z.__cuda_array_interface__["data"][0])
      self.assertAllClose(x, z)
      if 0 not in shape:
        # Device id detection from a zero pointer value is not possible, so
        # the device check only applies to nonempty arrays.
        self.assertEqual(y.devices(), z.devices())

      z = jnp.array(CAIObject(), copy=True)
      if 0 not in shape:
        self.assertNotEqual(y.__cuda_array_interface__["data"][0],
                            z.__cuda_array_interface__["data"][0])
      self.assertAllClose(x, z)


class Bfloat16Test(jtu.JaxTestCase):

  @unittest.skipIf((not tf or tf_version < (2, 5, 0)),
                   "Test requires TensorFlow 2.5.0 or newer")
  def testJaxAndTfHaveTheSameBfloat16Type(self):
    self.assertEqual(np.dtype(jnp.bfloat16).num,
                     np.dtype(tf.dtypes.bfloat16.as_numpy_dtype).num)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())