[JAX] Implement importing external dlpack-aware Python arrays.
See https://dmlc.github.io/dlpack/latest/python_spec.html.
This is the import path; the export path was implemented in 0b3cbfe4bc.
This allows creating jax.Arrays from external GPU arrays asynchronously.
PiperOrigin-RevId: 561172624
commit ecee8f9116 (parent e369445596)
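For orientation, here is a minimal usage sketch of the import path this commit adds. It is illustrative only and not part of the commit; CuPy stands in for any producer whose arrays expose __dlpack__ and __dlpack_device__ as described in the DLPack Python spec (a torch GPU tensor works the same way):

import cupy as cp   # assumed example producer; any DLPack-aware library works
import jax.dlpack

# A GPU array owned by the external library.
x_ext = cp.arange(8, dtype=cp.float32)

# from_dlpack now accepts the array object itself (no explicit capsule needed)
# and returns a jax.Array that shares memory with the producer's buffer.
y = jax.dlpack.from_dlpack(x_ext)
assert y.shape == (8,)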
@@ -15,7 +15,6 @@
 from __future__ import annotations

 from collections import defaultdict
-import enum
 import math
 import operator as op
 import numpy as np
@@ -50,14 +49,6 @@ Device = xc.Device
 Index = tuple[slice, ...]
 PRNGKeyArrayImpl = Any  # TODO(jakevdp): fix cycles and import this.


-# Mirror of dlpack.h enum
-class DLDeviceType(enum.IntEnum):
-  kDLCPU = 1
-  kDLCUDA = 2
-  kDLROCM = 10
-
-
 class Shard:
   """A single data shard of an Array.
@@ -386,10 +377,12 @@ class ArrayImpl(basearray.Array):
     from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
     return to_dlpack(self, stream=stream)

-  def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
+  def __dlpack_device__(self) -> tuple[int, int]:
     if len(self._arrays) != 1:
       raise ValueError("__dlpack__ only supported for unsharded arrays.")

+    from jax._src.dlpack import DLDeviceType  # pylint: disable=g-import-not-at-top
+
     if self.platform() == "cpu":
       return DLDeviceType.kDLCPU, 0
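Since jax.Array already exposes __dlpack__ and, after the change above, a spec-shaped __dlpack_device__, the new import path can consume jax.Arrays directly as well. A small round-trip sketch (illustrative only; it mirrors the testJaxArrayRoundTrip test added further down in this commit):

import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(4.0)

# (device_type, device_id) per the DLPack spec; (1, 0) means kDLCPU, device 0.
dev_type, dev_id = x.__dlpack_device__()

# No explicit to_dlpack capsule: from_dlpack queries __dlpack_device__ and then
# calls __dlpack__ itself, so the same source can be imported more than once.
y = jax.dlpack.from_dlpack(x)
assert (y == x).all()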
@@ -14,6 +14,8 @@

 from __future__ import annotations

+import enum
+
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
@@ -28,6 +30,13 @@ SUPPORTED_DTYPES = frozenset({
     jnp.float64, jnp.complex64, jnp.complex128})


+# Mirror of dlpack.h enum
+class DLDeviceType(enum.IntEnum):
+  kDLCPU = 1
+  kDLCUDA = 2
+  kDLROCM = 10
+
+
 def to_dlpack(x: Array, take_ownership: bool = False,
               stream: int | None = None):
   """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
@@ -63,14 +72,49 @@ def to_dlpack(x: Array, take_ownership: bool = False,
-def from_dlpack(dlpack):
+def from_dlpack(external_array):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.

-  The returned :class:`~jax.Array` shares memory with ``dlpack``.
+  The returned :class:`~jax.Array` shares memory with ``external_array``.

   Args:
-    dlpack: a DLPack tensor, on either CPU or GPU.
+    external_array: an array object that has __dlpack__ and __dlpack_device__
+      methods, or a DLPack tensor on either CPU or GPU (legacy API).

   Returns:
     A jax.Array
   """
+  if hasattr(external_array, "__dlpack__") and xla_extension_version >= 191:
+    dl_device_type, device_id = external_array.__dlpack_device__()
+    try:
+      device_platform = {
+          DLDeviceType.kDLCPU: "cpu",
+          DLDeviceType.kDLCUDA: "cuda",
+          DLDeviceType.kDLROCM: "rocm",
+      }[dl_device_type]
+    except KeyError:
+      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends
+      # raising TypeError for an unsupported device type.
+      raise TypeError(
+          "Array passed to from_dlpack is on unsupported device type "
+          f"(DLDeviceType: {dl_device_type}, array: {external_array})")
+
+    backend = xla_bridge.get_backend(device_platform)
+    device = backend.device_from_local_hardware_id(device_id)
+    try:
+      stream = device.get_stream_for_external_ready_events()
+    except xla_client.XlaRuntimeError as err:  # type: ignore
+      if "UNIMPLEMENTED" in str(err):
+        stream = None
+      else:
+        raise
+    dlpack = external_array.__dlpack__(stream)
+
+    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+        dlpack, device, stream))
+  else:
+    # Legacy path
+    dlpack = external_array
+
   cpu_backend = xla_bridge.get_backend("cpu")
   try:
     gpu_backend = xla_bridge.get_backend("cuda")
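To make the consumer logic above concrete, the following sketch shows the producer-side protocol that from_dlpack now relies on, per the DLPack Python spec. MyDLPackArray is a hypothetical, CPU-only stand-in and the capsule is borrowed from NumPy; real producers such as torch or CuPy implement this natively, including making GPU data ready with respect to the consumer stream passed into __dlpack__:

import numpy as np
import jax.dlpack

class MyDLPackArray:
  """Toy CPU-only producer exposing the DLPack protocol."""

  def __init__(self, data):
    # float32 keeps the example independent of the jax_enable_x64 setting.
    self._data = np.asarray(data, dtype=np.float32)

  def __dlpack_device__(self):
    # (device_type, device_id); 1 is kDLCPU in dlpack.h.
    return (1, 0)

  def __dlpack__(self, stream=None):
    # A GPU producer must make its data ready on `stream` (the consumer's
    # stream) before the capsule is consumed; a CPU producer can ignore it.
    # NumPy provides the actual DLPack capsule here.
    return self._data.__dlpack__()

y = jax.dlpack.from_dlpack(MyDLPackArray([1.0, 2.0, 3.0]))
assert y.shape == (3,)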
@@ -21,6 +21,7 @@ from jax import config
 import jax.dlpack
 import jax.numpy as jnp
 from jax._src import test_util as jtu
+from jax._src.lib import xla_extension_version

 import numpy as np
@@ -87,6 +88,30 @@ class DLPackTest(jtu.JaxTestCase):
         "DLPack tensor may be consumed at most once",
         lambda: jax.dlpack.from_dlpack(dlpack))

+  @jtu.sample_product(
+      shape=all_shapes,
+      dtype=dlpack_dtypes,
+      gpu=[False, True],
+  )
+  def testJaxArrayRoundTrip(self, shape, dtype, gpu):
+    if xla_extension_version < 191:
+      self.skipTest("Need xla_extension_version >= 191")
+
+    rng = jtu.rand_default(self.rng())
+    x_np = rng(shape, dtype)
+    if gpu and jax.default_backend() == "cpu":
+      raise unittest.SkipTest("Skipping GPU test case on CPU")
+    device = jax.devices("gpu" if gpu else "cpu")[0]
+    x = jax.device_put(x_np, device)
+    y = jax.dlpack.from_dlpack(x)
+    self.assertEqual(y.device(), device)
+    self.assertAllClose(x_np.astype(x.dtype), y)
+    # Test that we can create multiple arrays from the same source.
+    z = jax.dlpack.from_dlpack(x)
+    self.assertEqual(z.device(), device)
+    self.assertAllClose(x_np.astype(x.dtype), z)
+
   @jtu.sample_product(
       shape=all_shapes,
       dtype=dlpack_dtypes,
@@ -137,6 +137,30 @@ class DLPackTest(jtu.JaxTestCase):
     z = jax.jit(lambda x: x + 1)(y)
     self.assertAllClose(x_np + dtype(1), z)

+  @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
+  def testTorchToJaxArray(self, shape, dtype):
+    if xla_extension_version < 191:
+      self.skipTest("Need xla_extension_version >= 191")
+
+    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
+                                            jnp.complex128]:
+      self.skipTest("x64 types are disabled by jax_enable_x64")
+
+    rng = jtu.rand_default(self.rng())
+    x_np = rng(shape, dtype)
+    if dtype == jnp.bfloat16:
+      x = torch.tensor(x_np.view(jnp.int16)).view(torch.bfloat16)
+    else:
+      x = torch.tensor(x_np)
+    x = x.cuda() if jtu.device_under_test() == "gpu" else x
+    x = x.contiguous()
+    y = jax.dlpack.from_dlpack(x)
+    self.assertAllClose(x_np, y)
+
+    # Verify the resulting value can be passed to a jit computation.
+    z = jax.jit(lambda x: x + 1)(y)
+    self.assertAllClose(x_np + dtype(1), z)
+

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())