rocm_jax/jax/_src/dlpack.py
Jieying Luo 269d7ce5c1 Remove take_ownership support in DLPack.
When take_ownership is true, the original buffer is marked as deleted, which guarantees that JAX will not attempt to read or write it. This provides better error checking, but at the cost of one more C++ API and two more C APIs. The same semantics can be achieved without take_ownership by being careful on the caller side, so we decided to remove take_ownership support from DLPack.

PiperOrigin-RevId: 572278488
2023-10-10 09:43:02 -07:00
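
In practice, the "being careful" replacement amounts to exporting without an ownership transfer and then retiring the JAX handle by hand. A minimal sketch (assuming the public jax.dlpack alias for this module and the jax.Array.delete() method; the consuming framework is left abstract):

    import jax.numpy as jnp
    from jax import dlpack as jax_dlpack

    x = jnp.arange(4.0)
    capsule = jax_dlpack.to_dlpack(x)  # export; x is NOT marked deleted
    # ... hand `capsule` to the consuming framework ...
    x.delete()  # optional: make later JAX reads/writes of x fail loudly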

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import enum
from typing import Any
import warnings

from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
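
# Dtypes that can be exported to and imported from DLPack.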
SUPPORTED_DTYPES = frozenset({
    jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16,
    jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32,
    jnp.float64, jnp.complex64, jnp.complex128})


# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):
  kDLCPU = 1
  kDLCUDA = 2
  kDLROCM = 10
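# Producers report their device as one of these values; for example, a NumPy
# array's __dlpack_device__() returns (1, 0), i.e. (kDLCPU, device id 0).
# from_dlpack below maps these values onto XLA platform names.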


def to_dlpack(x: Array, take_ownership: bool = False,
              stream: int | Any | None = None):
  """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.

  Args:
    x: a :class:`~jax.Array`, on either CPU or GPU.
    take_ownership: Deprecated for xla_extension_version greater than or equal
      to 204, where setting it is a no-op; it will be deleted in 01/2024. For
      xla_extension_version less than 204, if ``True``, the JAX buffer acts as
      if it were deleted.
    stream: optional platform-dependent stream to wait on until the buffer is
      ready. This corresponds to the `stream` argument to ``__dlpack__``
      documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
  """
  if not isinstance(x, array.ArrayImpl):
    raise TypeError("Argument to to_dlpack must be a jax.Array, "
                    f"got {type(x)}")
  assert len(x.devices()) == 1
  if take_ownership:
    warnings.warn(
        "take_ownership in to_dlpack is deprecated and it is a no-op."
    )
  if xla_extension_version >= 204:
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
        x.addressable_data(0), stream=stream
    )  # type: ignore
  elif xla_extension_version >= 186:
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
        x.addressable_data(0), take_ownership=take_ownership, stream=stream
    )  # type: ignore
  else:
    if stream is not None:
      raise ValueError(
          "passing `stream` argument to to_dlpack requires jaxlib >= 0.4.15")
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
        x.addressable_data(0), take_ownership=take_ownership)  # type: ignore


def from_dlpack(external_array):
  """Returns a :class:`~jax.Array` representation of a DLPack tensor.

  The returned :class:`~jax.Array` shares memory with ``external_array``.

  Args:
    external_array: an array object that has ``__dlpack__`` and
      ``__dlpack_device__`` methods, or a DLPack tensor on either CPU or GPU
      (legacy API).

  Returns:
    A jax.Array
  """
  if hasattr(external_array, "__dlpack__") and xla_extension_version >= 191:
    dl_device_type, device_id = external_array.__dlpack_device__()
    try:
      device_platform = {
          DLDeviceType.kDLCPU: "cpu",
          DLDeviceType.kDLCUDA: "cuda",
          DLDeviceType.kDLROCM: "rocm",
      }[dl_device_type]
    except KeyError:
      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends
      # raising TypeError for unsupported device types.
      raise TypeError(
          "Array passed to from_dlpack is on unsupported device type "
          f"(DLDeviceType: {dl_device_type}, array: {external_array})")

    backend = xla_bridge.get_backend(device_platform)
    device = backend.device_from_local_hardware_id(device_id)
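    # Per the DLPack Python spec, the consumer (here JAX) offers the producer
    # a stream to synchronize on; backends without stream support report
    # UNIMPLEMENTED, in which case we fall back to passing no stream.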
    try:
      stream = device.get_stream_for_external_ready_events()
    except xla_client.XlaRuntimeError as err:  # type: ignore
      if "UNIMPLEMENTED" in str(err):
        stream = None
      else:
        raise
    dlpack = external_array.__dlpack__(stream=stream)

    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, device, stream))
  else:
    # Legacy path
    dlpack = external_array
    cpu_backend = xla_bridge.get_backend("cpu")
    try:
      gpu_backend = xla_bridge.get_backend("cuda")
    except RuntimeError:
      gpu_backend = None

    # Try ROCm if the CUDA backend was not found.
    if gpu_backend is None:
      try:
        gpu_backend = xla_bridge.get_backend("rocm")
      except RuntimeError:
        gpu_backend = None
    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, cpu_backend, gpu_backend))