mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 09:26:07 +00:00

When take_ownership is true, the original buffer is marked as deleted, which enforces that JAX won't attempt to read or write the buffer. This provides better error checking, but at the cost of one more C++ API and two more C APIs. The same semantics can be achieved by not using take_ownership and being careful. Therefore, we decided to remove take_ownership support in DLPack. PiperOrigin-RevId: 572278488
139 lines
4.7 KiB
Python
139 lines
4.7 KiB
Python
# Copyright 2020 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import enum
|
|
from typing import Any
|
|
import warnings
|
|
|
|
from jax import numpy as jnp
|
|
from jax._src import array
|
|
from jax._src import xla_bridge
|
|
from jax._src.lib import xla_client
|
|
from jax._src.lib import xla_extension_version
|
|
from jax._src.typing import Array
|
|
|
|
|
|
# Immutable set of the jax.numpy scalar types listed for DLPack interchange,
# grouped by family: signed ints, unsigned ints, floats, complex.
SUPPORTED_DTYPES = frozenset((
    jnp.int8, jnp.int16, jnp.int32, jnp.int64,
    jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
    jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64,
    jnp.complex64, jnp.complex128,
))
|
|
|
|
|
|
# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):
  """Integer device-type codes mirrored from the ``dlpack.h`` enum.

  Only the device types this module maps to a backend platform are listed.
  Being an ``IntEnum``, members compare and hash equal to their plain
  integer values.
  """

  # Host (CPU) memory.
  kDLCPU = 1
  # CUDA GPU device memory.
  kDLCUDA = 2
  # ROCm GPU device memory.
  kDLROCM = 10
|
|
|
|
|
|
def to_dlpack(x: Array, take_ownership: bool = False,
              stream: int | Any | None = None):
  """Exports a single-device :class:`~jax.Array` ``x`` as a DLPack tensor.

  Args:
    x: a :class:`~jax.Array`, on either CPU or GPU.
    take_ownership: deprecated, scheduled for removal in 01/2024. When
      xla_extension_version is 204 or newer the flag is not forwarded to the
      runtime, so setting it has no effect. On older versions it is forwarded
      and, if ``True``, the JAX buffer acts as if it were deleted.
    stream: optional platform-dependent stream to wait on until the buffer is
      ready. This corresponds to the `stream` argument to ``__dlpack__``
      documented in https://dmlc.github.io/dlpack/latest/python_spec.html.

  Returns:
    The DLPack tensor built by the runtime from the first addressable shard
    of ``x``.

  Raises:
    TypeError: if ``x`` is not an ``array.ArrayImpl`` instance.
    ValueError: if ``stream`` is given while xla_extension_version is below
      186, where the runtime entry point is called without ``stream``.
  """
  if not isinstance(x, array.ArrayImpl):
    message = ("Argument to to_dlpack must be a jax.Array, "
               f"got {type(x)}")
    raise TypeError(message)

  # Only arrays that live on exactly one device can be exported.
  device_count = len(x.devices())
  assert device_count == 1

  if take_ownership:
    warnings.warn(
        "take_ownership in to_dlpack is deprecated and it is a no-op."
    )

  # Choose the keyword arguments that the installed runtime understands; the
  # export call itself is shared by all three version ranges.
  if xla_extension_version >= 204:
    export_kwargs = {"stream": stream}
  elif xla_extension_version >= 186:
    export_kwargs = {"take_ownership": take_ownership, "stream": stream}
  else:
    if stream is not None:
      raise ValueError(
          "passing `stream` argument to to_dlpack requires jaxlib >= 0.4.15")
    export_kwargs = {"take_ownership": take_ownership}

  return xla_client._xla.buffer_to_dlpack_managed_tensor(
      x.addressable_data(0), **export_kwargs
  )  # type: ignore
|
|
|
|
|
|
|
|
def from_dlpack(external_array):
  """Returns a :class:`~jax.Array` representation of a DLPack tensor.

  The returned :class:`~jax.Array` shares memory with ``external_array``.

  Args:
    external_array: an array object that has __dlpack__ and __dlpack_device__
      methods, or a DLPack tensor on either CPU or GPU (legacy API).

  Returns:
    A jax.Array

  Raises:
    TypeError: if ``external_array`` implements ``__dlpack__`` but reports a
      DLPack device type other than CPU, CUDA or ROCm.
  """
  if hasattr(external_array, "__dlpack__") and xla_extension_version >= 191:
    dl_device_type, device_id = external_array.__dlpack_device__()
    try:
      # DLDeviceType is an IntEnum, so a plain integer device type reported by
      # the producer matches these keys.
      device_platform = {
          DLDeviceType.kDLCPU: "cpu",
          DLDeviceType.kDLCUDA: "cuda",
          DLDeviceType.kDLROCM: "rocm",
      }[dl_device_type]
    except (KeyError, TypeError):
      # An unknown device type makes the lookup raise KeyError (TypeError only
      # for an unhashable value).
      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
      # TypeError.
      raise TypeError(
          "Array passed to from_dlpack is on unsupported device type "
          f"(DLDeviceType: {dl_device_type}, array: {external_array})"
      ) from None

    backend = xla_bridge.get_backend(device_platform)
    device = backend.device_from_local_hardware_id(device_id)
    try:
      stream = device.get_stream_for_external_ready_events()
    except xla_client.XlaRuntimeError as err:  # type: ignore
      # Devices that do not implement external-ready streams fall back to
      # passing no stream; any other runtime error is propagated.
      if "UNIMPLEMENTED" in str(err):
        stream = None
      else:
        raise
    dlpack = external_array.__dlpack__(stream=stream)

    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, device, stream))

  # Legacy path: ``external_array`` is itself the DLPack tensor.
  dlpack = external_array
  cpu_backend = xla_bridge.get_backend("cpu")

  # Use the CUDA backend if present, otherwise try ROCm; ``None`` when neither
  # GPU backend is available.
  gpu_backend = None
  for gpu_platform in ("cuda", "rocm"):
    try:
      gpu_backend = xla_bridge.get_backend(gpu_platform)
    except RuntimeError:
      continue
    break

  return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
      dlpack, cpu_backend, gpu_backend))
|