Update from_dlpack to match array API 2023

commit 2b1c3deee2 (parent fed7efd730)
@@ -18,13 +18,14 @@ import enum
 from typing import Any
 import warnings
 
+from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src.typing import Array
 
+from jax._src.sharding import Sharding
 
 # A set of dtypes that dlpack supports.
 # Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -82,16 +83,111 @@ def to_dlpack(x: Array, take_ownership: bool = False,
       x.addressable_data(0), stream=stream
   )  # type: ignore
 
+def _place_array(_arr, device, dlpack_device, copy):
+  if device and dlpack_device != device:
+    if copy is not None and not copy:
+      raise ValueError(
+          f"Specified {device=} which requires a copy since the source device "
+          f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
+          "copy=None to perform the requested operation."
+      )
+    else:
+      return device_put(_arr, device)
+  if copy:
+    return jnp.array(_arr, copy=True)
+  return _arr
 
-def from_dlpack(external_array):
+def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
+                        copy: bool | None = None):
+  preferred_platform = getattr(device, "platform", None)
+  if device and preferred_platform == "gpu":
+    preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
+
+  cpu_backend = xla_bridge.get_backend("cpu")
+  gpu_backend = None
+
+  if preferred_platform in {"cuda", "rocm"}:
+    try:
+      gpu_backend = xla_bridge.get_backend(preferred_platform)
+    except RuntimeError:
+      raise TypeError(
+          f"A {str.upper(preferred_platform)} device was specified, however no "
+          f"{str.upper(preferred_platform)} backend was found."
+      )
+
+  if preferred_platform is None:
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
+    except RuntimeError:
+      pass
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        pass
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))  # type: ignore
+  dlpack_device, = _arr.devices()
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def _from_dlpack(external_array, device: xla_client.Device | None = None,
+                 copy: bool | None = None):
+  dl_device_type, device_id = external_array.__dlpack_device__()
+  try:
+    dl_device_platform = {
+        DLDeviceType.kDLCPU: "cpu",
+        DLDeviceType.kDLCUDA: "cuda",
+        DLDeviceType.kDLROCM: "rocm",
+    }[dl_device_type]
+  except TypeError:
+    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+    # TypeError.
+    raise TypeError(
+        "Array passed to from_dlpack is on unsupported device type "
+        f"(DLDeviceType: {dl_device_type}, array: {external_array})")
+
+  backend = xla_bridge.get_backend(dl_device_platform)
+  dlpack_device = backend.device_from_local_hardware_id(device_id)
+  try:
+    stream = dlpack_device.get_stream_for_external_ready_events()
+  except xla_client.XlaRuntimeError as err:  # type: ignore
+    if "UNIMPLEMENTED" in str(err):
+      stream = None
+    else:
+      raise
+  dlpack = external_array.__dlpack__(stream=stream)
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, dlpack_device, stream))
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def from_dlpack(external_array,
+                device: xla_client.Device | Sharding | None = None,
+                copy: bool | None = None):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.
 
-  The returned :class:`~jax.Array` shares memory with ``external_array``.
+  The returned :class:`~jax.Array` shares memory with ``external_array`` if no
+  device transfer or copy was requested.
 
   Args:
-    external_array: an array object that has __dlpack__ and __dlpack_device__
+    external_array: An array object that has __dlpack__ and __dlpack_device__
       methods, or a DLPack tensor on either CPU or GPU (legacy API).
+
+    device: The (optional) :py:class:`Device`, representing the device on which
+      the returned array should be placed. If given, then the result is committed
+      to the device. If unspecified, the resulting array will be unpacked onto the
+      same device it originated from. Setting ``device`` to a device different from
+      the source of ``external_array`` will require a copy, meaning ``copy`` must be
+      set to either ``True`` or ``None``.
+
+    copy: An (optional) boolean, controlling whether or not a copy is performed.
+      If ``copy=True`` then a copy is always performed, even if unpacked onto the
+      same device. If ``copy=False`` then a copy is never performed, and an error
+      is raised if one would be necessary. When ``copy=None`` a copy may be
+      performed if needed for a device transfer.
 
   Returns:
     A jax.Array
 
@@ -102,49 +198,16 @@ def from_dlpack(external_array):
     is later modified in-place, it may lead to undefined behavior when using
     the associated JAX array.
   """
+  if isinstance(device, Sharding):
+    device_set = device.device_set
+    if len(device_set) > 1:
+      raise ValueError(
+          "from_dlpack can only unpack a dlpack tensor onto a singular device, but "
+          f"a Sharding with {len(device_set)} devices was provided."
+      )
+    device, = device_set
   if hasattr(external_array, "__dlpack__"):
-    dl_device_type, device_id = external_array.__dlpack_device__()
-    try:
-      device_platform = {
-          DLDeviceType.kDLCPU: "cpu",
-          DLDeviceType.kDLCUDA: "cuda",
-          DLDeviceType.kDLROCM: "rocm",
-      }[dl_device_type]
-    except TypeError:
-      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
-      # TypeError.
-      raise TypeError(
-          "Array passed to from_dlpack is on unsupported device type "
-          f"(DLDeviceType: {dl_device_type}, array: {external_array}")
+    return _from_dlpack(external_array, device, copy)
 
-    backend = xla_bridge.get_backend(device_platform)
-    device = backend.device_from_local_hardware_id(device_id)
-    try:
-      stream = device.get_stream_for_external_ready_events()
-    except xla_client.XlaRuntimeError as err:  # type: ignore
-      if "UNIMPLEMENTED" in str(err):
-        stream = None
-      else:
-        raise
-    dlpack = external_array.__dlpack__(stream=stream)
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, device, stream))
-  else:
-    # Legacy path
-    dlpack = external_array
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("cuda")
-    except RuntimeError:
-      gpu_backend = None
-
-    # Try ROCm if CUDA backend not found
-    if gpu_backend is None:
-      try:
-        gpu_backend = xla_bridge.get_backend("rocm")
-      except RuntimeError:
-        gpu_backend = None
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
+  # Legacy path
+  return _legacy_from_dlpack(external_array, device, copy)
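A minimal usage sketch of the updated API (illustrative, not part of the diff; assumes at least one available JAX device and a NumPy new enough that arrays expose __dlpack__):

    # Hedged sketch of the new device/copy keywords introduced above.
    import numpy as np
    import jax
    from jax.dlpack import from_dlpack

    x_np = np.arange(4.0)                  # any object with __dlpack__/__dlpack_device__
    x_view = from_dlpack(x_np)             # may share memory: no transfer or copy requested
    x_copy = from_dlpack(x_np, copy=True)  # always copies, even onto the same device
    x_dev = from_dlpack(x_np, device=jax.devices()[0], copy=None)  # copies only if a transfer is needed
    # Requesting a device different from the source with copy=False raises
    # ValueError, since a cross-device placement requires a copy.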
@@ -2442,9 +2442,10 @@ def fromiter(*args, **kwargs):
     is later modified in-place, it may lead to undefined behavior when using
     the associated JAX array.
   """)
-def from_dlpack(x: Any) -> Array:
+def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
+                copy: bool | None = None) -> Array:
   from jax.dlpack import from_dlpack  # pylint: disable=g-import-not-at-top
-  return from_dlpack(x)
+  return from_dlpack(x, device=device, copy=copy)
 
 @util.implements(np.fromfunction)
 def fromfunction(function: Callable[..., Array], shape: Any,
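Since the jax.numpy wrapper forwards device and copy unchanged, a single-device Sharding is also accepted (multi-device shardings raise ValueError in jax.dlpack.from_dlpack, per the check added above). An illustrative sketch, not part of the diff:

    # Hedged sketch: passing a single-device Sharding as the placement target.
    import numpy as np
    import jax
    import jax.numpy as jnp

    sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    y = jnp.from_dlpack(np.zeros(2), device=sharding, copy=None)  # committed to that device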
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import jax
 import jax.numpy as jnp
+
+from jax._src.lib import xla_client as xc
+from jax._src.sharding import Sharding
 
 def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
   return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
@@ -31,8 +34,8 @@ def empty_like(x, /, *, dtype=None, device=None):
 def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
   return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
 
-def from_dlpack(x, /):
-  return jnp.from_dlpack(x)
+def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
+  return jnp.from_dlpack(x, device=device, copy=copy)
 
 def full(shape, fill_value, *, dtype=None, device=None):
   return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
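This array API wrapper likewise just forwards to jnp.from_dlpack, so the namespace picks up the same keyword-only semantics. A hedged usage sketch; the import path below is an assumption, since the diff does not show the module's public name:

    # Assumed import path (e.g. jax.experimental.array_api); not shown in the diff.
    import numpy as np
    from jax.experimental import array_api as xp

    x = xp.from_dlpack(np.ones(3), copy=True)  # device/copy are keyword-only, per the 2023 spec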
@@ -353,7 +353,8 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
-def from_dlpack(x: Any) -> Array: ...
+def from_dlpack(x: Any, /, *, device: _Device | None = None,
+                copy: builtins.bool | None = None) -> Array: ...
 def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
                count: int = ..., offset: int = ...) -> Array: ...
 def fromfile(*args, **kwargs): ...
@@ -174,12 +174,21 @@ class DLPackTest(jtu.JaxTestCase):
   @jtu.sample_product(
       shape=all_shapes,
       dtype=numpy_dtypes,
+      copy=[False, True],
   )
-  def testNumpyToJax(self, shape, dtype):
+  def testNumpyToJax(self, shape, dtype, copy):
     rng = jtu.rand_default(self.rng())
     x_np = rng(shape, dtype)
-    x_jax = jnp.from_dlpack(x_np)
-    self.assertAllClose(x_np, x_jax)
+    device = jax.devices()[0]
+    _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
+    if jax.default_backend() == 'gpu' and not copy:
+      self.assertRaisesRegex(
+          ValueError,
+          r"Specified .* which requires a copy",
+          _from_dlpack
+      )
+    else:
+      self.assertAllClose(x_np, _from_dlpack())
 
   @jtu.sample_product(
       shape=all_shapes,