diff --git a/jax/_src/array.py b/jax/_src/array.py
index 1261d6f71..d4cf87ad3 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -15,7 +15,6 @@
 from __future__ import annotations
 
 from collections import defaultdict
-import enum
 import math
 import operator as op
 import numpy as np
@@ -50,14 +49,6 @@
 Device = xc.Device
 Index = tuple[slice, ...]
 PRNGKeyArrayImpl = Any  # TODO(jakevdp): fix cycles and import this.
 
-
-# Mirror of dlpack.h enum
-class DLDeviceType(enum.IntEnum):
-  kDLCPU = 1
-  kDLCUDA = 2
-  kDLROCM = 10
-
-
 class Shard:
   """A single data shard of an Array.
@@ -386,10 +377,12 @@ class ArrayImpl(basearray.Array):
     from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
     return to_dlpack(self, stream=stream)
 
-  def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
+  def __dlpack_device__(self) -> tuple[int, int]:
     if len(self._arrays) != 1:
       raise ValueError("__dlpack__ only supported for unsharded arrays.")
 
+    from jax._src.dlpack import DLDeviceType  # pylint: disable=g-import-not-at-top
+
     if self.platform() == "cpu":
       return DLDeviceType.kDLCPU, 0
 
diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py
index eff01b8fc..2e2705e23 100644
--- a/jax/_src/dlpack.py
+++ b/jax/_src/dlpack.py
@@ -14,6 +14,8 @@
 
 from __future__ import annotations
 
+import enum
+
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
@@ -28,6 +30,13 @@ SUPPORTED_DTYPES = frozenset({
     jnp.float64, jnp.complex64, jnp.complex128})
 
 
+# Mirror of dlpack.h enum
+class DLDeviceType(enum.IntEnum):
+  kDLCPU = 1
+  kDLCUDA = 2
+  kDLROCM = 10
+
+
 def to_dlpack(x: Array, take_ownership: bool = False,
               stream: int | None = None):
   """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
@@ -63,26 +72,61 @@ def to_dlpack(x: Array, take_ownership: bool = False,
 
 
-def from_dlpack(dlpack):
+def from_dlpack(external_array):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.
 
-  The returned :class:`~jax.Array` shares memory with ``dlpack``.
+  The returned :class:`~jax.Array` shares memory with ``external_array``.
 
   Args:
-    dlpack: a DLPack tensor, on either CPU or GPU.
-  """
-  cpu_backend = xla_bridge.get_backend("cpu")
-  try:
-    gpu_backend = xla_bridge.get_backend("cuda")
-  except RuntimeError:
-    gpu_backend = None
+    external_array: an array object that has __dlpack__ and __dlpack_device__
+      methods, or a DLPack tensor on either CPU or GPU (legacy API).
 
-  # Try ROCm if CUDA backend not found
-  if gpu_backend is None:
+  Returns:
+    A jax.Array
+  """
+  if hasattr(external_array, "__dlpack__") and xla_extension_version >= 191:
+    dl_device_type, device_id = external_array.__dlpack_device__()
     try:
-      gpu_backend = xla_bridge.get_backend("rocm")
+      device_platform = {
+          DLDeviceType.kDLCPU: "cpu",
+          DLDeviceType.kDLCUDA: "cuda",
+          DLDeviceType.kDLROCM: "rocm",
+      }[dl_device_type]
+    except KeyError:
+      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+      # TypeError.
+      raise TypeError(
+          "Array passed to from_dlpack is on unsupported device type "
+          f"(DLDeviceType: {dl_device_type}, array: {external_array})")
+
+    backend = xla_bridge.get_backend(device_platform)
+    device = backend.device_from_local_hardware_id(device_id)
+    try:
+      stream = device.get_stream_for_external_ready_events()
+    except xla_client.XlaRuntimeError as err:  # type: ignore
+      if "UNIMPLEMENTED" in str(err):
+        stream = None
+      else:
+        raise
+    dlpack = external_array.__dlpack__(stream)
+
+    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+        dlpack, device, stream))
+  else:
+    # Legacy path
+    dlpack = external_array
+    cpu_backend = xla_bridge.get_backend("cpu")
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
     except RuntimeError:
       gpu_backend = None
 
-  return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-      dlpack, cpu_backend, gpu_backend))
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        gpu_backend = None
+
+    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+        dlpack, cpu_backend, gpu_backend))
 
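
Not part of the patch, just a consumer-side sketch of what the new from_dlpack path enables. It assumes xla_extension_version >= 191 and a producer that implements the DLPack protocol (for example NumPy 1.22 or newer); older producers and plain capsules still go through the legacy branch shown above.

import numpy as np

import jax
import jax.dlpack
import jax.numpy as jnp

# New path: any object exposing __dlpack__/__dlpack_device__ can be handed
# over directly. from_dlpack reads the device, negotiates a stream, and only
# then asks the producer for the capsule.
x_np = np.arange(8, dtype=np.float32)
y = jax.dlpack.from_dlpack(x_np)

# Legacy path: an explicitly created DLPack capsule is still accepted.
capsule = jax.dlpack.to_dlpack(jnp.arange(8.0))
z = jax.dlpack.from_dlpack(capsule)
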
diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py
index d761d5be7..fb12756f2 100644
--- a/tests/array_interoperability_test.py
+++ b/tests/array_interoperability_test.py
@@ -21,6 +21,7 @@
 from jax import config
 import jax.dlpack
 import jax.numpy as jnp
 from jax._src import test_util as jtu
+from jax._src.lib import xla_extension_version
 
 import numpy as np
@@ -87,6 +88,30 @@ class DLPackTest(jtu.JaxTestCase):
                            "DLPack tensor may be consumed at most once",
                            lambda: jax.dlpack.from_dlpack(dlpack))
 
+  @jtu.sample_product(
+      shape=all_shapes,
+      dtype=dlpack_dtypes,
+      gpu=[False, True],
+  )
+  def testJaxArrayRoundTrip(self, shape, dtype, gpu):
+    if xla_extension_version < 191:
+      self.skipTest("Need xla_extension_version >= 191")
+
+    rng = jtu.rand_default(self.rng())
+    np = rng(shape, dtype)
+    if gpu and jax.default_backend() == "cpu":
+      raise unittest.SkipTest("Skipping GPU test case on CPU")
+    device = jax.devices("gpu" if gpu else "cpu")[0]
+    x = jax.device_put(np, device)
+    y = jax.dlpack.from_dlpack(x)
+    self.assertEqual(y.device(), device)
+    self.assertAllClose(np.astype(x.dtype), y)
+    # Test we can create multiple arrays
+    z = jax.dlpack.from_dlpack(x)
+    self.assertEqual(z.device(), device)
+    self.assertAllClose(np.astype(x.dtype), z)
+
+
   @jtu.sample_product(
       shape=all_shapes,
       dtype=dlpack_dtypes,
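
The round-trip test above relies on jax.Array advertising its placement through __dlpack_device__. A minimal sketch of that handshake, assuming a CPU-only install (on GPU the tuple would carry kDLCUDA or kDLROCM and the local device ordinal); DLDeviceType comes from the private module added in this patch, so treat the snippet as illustrative only:

import jax
import jax.dlpack
import jax.numpy as jnp
from jax._src.dlpack import DLDeviceType

x = jnp.ones((3, 4))

# Producer side of the protocol: report where the data lives.
device_type, device_id = x.__dlpack_device__()
assert device_type == DLDeviceType.kDLCPU and device_id == 0

# Consumer side: from_dlpack maps that tuple to a backend and local device,
# then requests the capsule via x.__dlpack__(stream).
y = jax.dlpack.from_dlpack(x)
assert y.device() == x.device()
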
diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py
index 7c21c1381..eecd6132e 100644
--- a/tests/pytorch_interoperability_test.py
+++ b/tests/pytorch_interoperability_test.py
@@ -137,6 +137,30 @@ class DLPackTest(jtu.JaxTestCase):
     z = jax.jit(lambda x: x + 1)(y)
     self.assertAllClose(x_np + dtype(1), z)
 
+  @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
+  def testTorchToJaxArray(self, shape, dtype):
+    if xla_extension_version < 191:
+      self.skipTest("Need xla_extension_version >= 191")
+
+    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
+                                            jnp.complex128]:
+      self.skipTest("x64 types are disabled by jax_enable_x64")
+
+    rng = jtu.rand_default(self.rng())
+    x_np = rng(shape, dtype)
+    if dtype == jnp.bfloat16:
+      x = torch.tensor(x_np.view(jnp.int16)).view(torch.bfloat16)
+    else:
+      x = torch.tensor(x_np)
+    x = x.cuda() if jtu.device_under_test() == "gpu" else x
+    x = x.contiguous()
+    y = jax.dlpack.from_dlpack(x)
+    self.assertAllClose(x_np, y)
+
+    # Verify the resulting value can be passed to a jit computation.
+    z = jax.jit(lambda x: x + 1)(y)
+    self.assertAllClose(x_np + dtype(1), z)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
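
A usage note to go with the new PyTorch test: a torch.Tensor can now be passed to jax.dlpack.from_dlpack directly, without first building a capsule via torch.utils.dlpack. A rough sketch, assuming a PyTorch release that implements __dlpack__ and __dlpack_device__ (roughly 1.10 or newer) and xla_extension_version >= 191:

import torch

import jax
import jax.dlpack

t = torch.arange(8, dtype=torch.float32)

# Before this change a capsule had to be produced explicitly:
#   y = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
# Now the tensor itself is enough; device and stream are negotiated through
# the DLPack protocol.
y = jax.dlpack.from_dlpack(t)

# The imported value behaves like any other jax.Array.
z = jax.jit(lambda a: a + 1)(y)
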