Fix dlpack type signatures to match Array API spec.

Fixes https://github.com/google/jax/issues/17510
This commit is contained in:
Peter Hawkins 2023-09-08 09:18:38 -04:00
parent 601d67ae66
commit 3a4b60b48c
2 changed files with 6 additions and 4 deletions

View File

@ -15,6 +15,7 @@
from __future__ import annotations
from collections import defaultdict
import enum
import math
import operator as op
import numpy as np
@ -371,13 +372,13 @@ class ArrayImpl(basearray.Array):
def __array__(self, dtype=None, context=None):
return np.asarray(self._value, dtype=dtype)
def __dlpack__(self, stream: int | None = None):
def __dlpack__(self, *, stream: int | Any | None = None):
if len(self._arrays) != 1:
raise ValueError("__dlpack__ only supported for unsharded arrays.")
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self, stream=stream)
def __dlpack_device__(self) -> tuple[int, int]:
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
if len(self._arrays) != 1:
raise ValueError("__dlpack__ only supported for unsharded arrays.")

View File

@ -15,6 +15,7 @@
from __future__ import annotations
import enum
from typing import Any
from jax import numpy as jnp
from jax._src import array
@ -38,7 +39,7 @@ class DLDeviceType(enum.IntEnum):
def to_dlpack(x: Array, take_ownership: bool = False,
stream: int | None = None):
stream: int | Any | None = None):
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
Takes ownership of the contents of ``x``; leaves ``x`` in an invalid/deleted
@ -108,7 +109,7 @@ def from_dlpack(external_array):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream)
dlpack = external_array.__dlpack__(stream=stream)
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, device, stream))