Update from_dlpack to match array API 2023

Meekail Zain 2024-04-04 22:51:25 +00:00
parent fed7efd730
commit 2b1c3deee2
5 changed files with 134 additions and 57 deletions


@@ -18,13 +18,14 @@ import enum
from typing import Any
import warnings
from jax._src.api import device_put
from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
from jax._src.sharding import Sharding
# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -82,16 +83,111 @@ def to_dlpack(x: Array, take_ownership: bool = False,
x.addressable_data(0), stream=stream
) # type: ignore
def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
return device_put(_arr, device)
if copy:
return jnp.array(_arr, copy=True)
return _arr
def from_dlpack(external_array):
def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
copy: bool | None = None):
preferred_platform = getattr(device, "platform", None)
if device and preferred_platform == "gpu":
preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
cpu_backend = xla_bridge.get_backend("cpu")
gpu_backend = None
if preferred_platform in {"cuda", "rocm"}:
try:
gpu_backend = xla_bridge.get_backend(preferred_platform)
except RuntimeError:
raise TypeError(
f"A {str.upper(preferred_platform)} device was specified, however no "
f"{str.upper(preferred_platform)} backend was found."
)
if preferred_platform is None:
try:
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
pass
# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
pass
_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)) # type: ignore
dlpack_device, = _arr.devices()
return _place_array(_arr, device, dlpack_device, copy)
def _from_dlpack(external_array, device: xla_client.Device | None = None,
copy: bool | None = None):
dl_device_type, device_id = external_array.__dlpack_device__()
try:
dl_device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except KeyError:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
backend = xla_bridge.get_backend(dl_device_platform)
dlpack_device = backend.device_from_local_hardware_id(device_id)
try:
stream = dlpack_device.get_stream_for_external_ready_events()
except xla_client.XlaRuntimeError as err: # type: ignore
if "UNIMPLEMENTED" in str(err):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream=stream)
_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, dlpack_device, stream))
return _place_array(_arr, device, dlpack_device, copy)
def from_dlpack(external_array,
device: xla_client.Device | Sharding | None = None,
copy: bool | None = None):
"""Returns a :class:`~jax.Array` representation of a DLPack tensor.
The returned :class:`~jax.Array` shares memory with ``external_array``.
The returned :class:`~jax.Array` shares memory with ``external_array`` if no
device transfer or copy was requested.
Args:
external_array: an array object that has __dlpack__ and __dlpack_device__
external_array: An array object that has __dlpack__ and __dlpack_device__
methods, or a DLPack tensor on either CPU or GPU (legacy API).
device: The (optional) :py:class:`Device`, representing the device on which
the returned array should be placed. If given, then the result is committed
to the device. If unspecified, the resulting array will be unpacked onto the
same device it originated from. Setting ``device`` to a device different from
the source of ``external_array`` will require a copy, meaning ``copy`` must be
set to either ``True`` or ``None``.
copy: An (optional) boolean, controlling whether or not a copy is performed.
If ``copy=True`` then a copy is always performed, even when unpacking onto the
same device. If ``copy=False`` then a copy is never performed, and an error is
raised if one would be required. When ``copy=None``, a copy may be performed
if needed for a device transfer.
Returns:
A jax.Array
@@ -102,49 +198,16 @@ def from_dlpack(external_array):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
"""
if isinstance(device, Sharding):
device_set = device.device_set
if len(device_set) > 1:
raise ValueError(
"from_dlpack can only unpack a dlpack tensor onto a singular device, but "
f"a Sharding with {len(device_set)} devices was provided."
)
device, = device_set
if hasattr(external_array, "__dlpack__"):
dl_device_type, device_id = external_array.__dlpack_device__()
try:
device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except TypeError:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
return _from_dlpack(external_array, device, copy)
backend = xla_bridge.get_backend(device_platform)
device = backend.device_from_local_hardware_id(device_id)
try:
stream = device.get_stream_for_external_ready_events()
except xla_client.XlaRuntimeError as err: # type: ignore
if "UNIMPLEMENTED" in str(err):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream=stream)
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, device, stream))
else:
# Legacy path
dlpack = external_array
cpu_backend = xla_bridge.get_backend("cpu")
try:
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
gpu_backend = None
# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
gpu_backend = None
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend))
# Legacy path
return _legacy_from_dlpack(external_array, device, copy)
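
A minimal usage sketch of the semantics the new docstring describes (illustrative only, not part of the diff; it assumes a NumPy recent enough to provide __dlpack__, and the error case only triggers on a GPU backend):

import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(4.0)                 # any producer with __dlpack__ / __dlpack_device__
cpu = jax.devices("cpu")[0]

y = jnp.from_dlpack(x_np)                         # unpacked onto the source (CPU) device
z = jnp.from_dlpack(x_np, device=cpu, copy=True)  # copy=True forces a copy even on the same device

# Requesting a device other than the source with copy=False is rejected, since the
# transfer would need a copy; copy=None (the default) copies only when necessary.
if jax.default_backend() == "gpu":
    try:
        jnp.from_dlpack(x_np, device=jax.devices()[0], copy=False)
    except ValueError as err:
        print(err)  # "Specified device=... which requires a copy ..."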


@@ -2442,9 +2442,10 @@ def fromiter(*args, **kwargs):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
""")
def from_dlpack(x: Any) -> Array:
def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
copy: bool | None = None) -> Array:
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
return from_dlpack(x)
return from_dlpack(x, device=device, copy=copy)
@util.implements(np.fromfunction)
def fromfunction(function: Callable[..., Array], shape: Any,

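For callers, note that the updated public signature makes ``x`` positional-only and ``device``/``copy`` keyword-only; a quick illustrative sketch (not part of the diff):

import numpy as np
import jax.numpy as jnp

x = np.arange(3)
y = jnp.from_dlpack(x, copy=None)  # device and copy must be passed by keyword
# jnp.from_dlpack(x, None, None)   # would raise TypeError: device/copy are keyword-only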

@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
@@ -31,8 +34,8 @@ def empty_like(x, /, *, dtype=None, device=None):
def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
def from_dlpack(x, /):
return jnp.from_dlpack(x)
def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
return jnp.from_dlpack(x, device=device, copy=copy)
def full(shape, fill_value, *, dtype=None, device=None):
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
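
Because the array API wrapper simply forwards to jnp.from_dlpack, the new keywords are also available through the standard-conforming namespace; a minimal sketch, assuming jax.experimental.array_api is importable as it was around the time of this commit:

import numpy as np
from jax.experimental import array_api as xp  # experimental array API namespace at the time

x_np = np.ones((2, 3), dtype=np.float32)
y = xp.from_dlpack(x_np)             # same defaults as jnp.from_dlpack
z = xp.from_dlpack(x_np, copy=True)  # keyword-only copy, per the 2023 array API spec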


@@ -353,7 +353,8 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
def from_dlpack(x: Any) -> Array: ...
def from_dlpack(x: Any, /, *, device: _Device | None = None,
copy: builtins.bool | None = None) -> Array: ...
def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
count: int = ..., offset: int = ...) -> Array: ...
def fromfile(*args, **kwargs): ...


@@ -174,12 +174,21 @@ class DLPackTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=all_shapes,
dtype=numpy_dtypes,
copy=[False, True],
)
def testNumpyToJax(self, shape, dtype):
def testNumpyToJax(self, shape, dtype, copy):
rng = jtu.rand_default(self.rng())
x_np = rng(shape, dtype)
x_jax = jnp.from_dlpack(x_np)
self.assertAllClose(x_np, x_jax)
device = jax.devices()[0]
_from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
if jax.default_backend() == 'gpu' and not copy:
self.assertRaisesRegex(
ValueError,
r"Specified .* which requires a copy",
_from_dlpack
)
else:
self.assertAllClose(x_np, _from_dlpack())
@jtu.sample_product(
shape=all_shapes,