mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Fix dlpack type signatures to match Array API spec.
Fixes https://github.com/google/jax/issues/17510
This commit is contained in:
parent
601d67ae66
commit
3a4b60b48c
@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
import enum
|
||||
import math
|
||||
import operator as op
|
||||
import numpy as np
|
||||
@ -371,13 +372,13 @@ class ArrayImpl(basearray.Array):
|
||||
def __array__(self, dtype=None, context=None):
|
||||
return np.asarray(self._value, dtype=dtype)
|
||||
|
||||
def __dlpack__(self, stream: int | None = None):
|
||||
def __dlpack__(self, *, stream: int | Any | None = None):
|
||||
if len(self._arrays) != 1:
|
||||
raise ValueError("__dlpack__ only supported for unsharded arrays.")
|
||||
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
|
||||
return to_dlpack(self, stream=stream)
|
||||
|
||||
def __dlpack_device__(self) -> tuple[int, int]:
|
||||
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
|
||||
if len(self._arrays) != 1:
|
||||
raise ValueError("__dlpack__ only supported for unsharded arrays.")
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
from jax import numpy as jnp
|
||||
from jax._src import array
|
||||
@ -38,7 +39,7 @@ class DLDeviceType(enum.IntEnum):
|
||||
|
||||
|
||||
def to_dlpack(x: Array, take_ownership: bool = False,
|
||||
stream: int | None = None):
|
||||
stream: int | Any | None = None):
|
||||
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
|
||||
|
||||
Takes ownership of the contents of ``x``; leaves ``x`` in an invalid/deleted
|
||||
@ -108,7 +109,7 @@ def from_dlpack(external_array):
|
||||
stream = None
|
||||
else:
|
||||
raise
|
||||
dlpack = external_array.__dlpack__(stream)
|
||||
dlpack = external_array.__dlpack__(stream=stream)
|
||||
|
||||
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlpack, device, stream))
|
||||
|
Loading…
x
Reference in New Issue
Block a user