mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Improve type annotations for jax.Array methods
This commit is contained in:
parent
0c543aef1d
commit
dd697a9abc
@ -14,17 +14,29 @@
|
||||
import abc
|
||||
from collections.abc import Callable, Sequence
|
||||
from types import ModuleType
|
||||
from typing import Any, Union
|
||||
from typing import Any, Protocol, Union, runtime_checkable
|
||||
import numpy as np
|
||||
|
||||
from jax._src.sharding import Sharding
|
||||
|
||||
# TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py.
|
||||
# We redefine these here to prevent circular imports.
|
||||
@runtime_checkable
|
||||
class SupportsDType(Protocol):
|
||||
@property
|
||||
def dtype(self) -> np.dtype: ...
|
||||
DTypeLike = Union[str, type[Any], np.dtype, SupportsDType]
|
||||
|
||||
Axis = Union[int, Sequence[int], None]
|
||||
Shard = Any
|
||||
|
||||
# TODO: alias this to xla_client.Traceback
|
||||
Device = Any
|
||||
Traceback = Any
|
||||
|
||||
# TODO(jakevdp): fix import cycles and import this from jax._src.lax.
|
||||
PrecisionLike = Any
|
||||
|
||||
|
||||
class Array(abc.ABC):
|
||||
aval: Any
|
||||
@ -117,72 +129,89 @@ class Array(abc.ABC):
|
||||
def __release_buffer__(self, view: memoryview) -> None: ...
|
||||
|
||||
# np.ndarray methods:
|
||||
def all(self, axis: int | Sequence[int] | None = None, out=None,
|
||||
keepdims=None, *, where: ArrayLike | None = ...) -> Array: ...
|
||||
def any(self, axis: int | Sequence[int] | None = None, out=None,
|
||||
keepdims=None, *, where: ArrayLike | None = ...) -> Array: ...
|
||||
def argmax(self, axis: int | None = None, out=None, keepdims=None) -> Array: ...
|
||||
def argmin(self, axis: int | None = None, out=None, keepdims=None) -> Array: ...
|
||||
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
|
||||
def argsort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def astype(self, dtype) -> Array: ...
|
||||
def choose(self, choices, out=None, mode='raise') -> Array: ...
|
||||
def clip(self, min=None, max=None, out=None) -> Array: ...
|
||||
def compress(self, condition, axis: int | None = None, out=None) -> Array: ...
|
||||
def all(self, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
|
||||
def any(self: Array, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
|
||||
def argmax(self: Array, axis: int | None = None, out: None = None,
|
||||
keepdims: bool | None = None) -> Array: ...
|
||||
def argmin(self, axis: int | None = None, out: None = None,
|
||||
keepdims: bool | None = None) -> Array: ...
|
||||
def argpartition(self, kth, axis=-1, kind='introselect', order: None = None) -> Array: ...
|
||||
def argsort(self, axis: int | None = -1, kind='quicksort', order: None = None) -> Array: ...
|
||||
def astype(self, dtype: DTypeLike | None = None, max: ArrayLike | None = None) -> Array: ...
|
||||
def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ...
|
||||
def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ...
|
||||
def compress(self, condition: ArrayLike,
|
||||
axis: int | None = None, *, out: None = None,
|
||||
size: int | None = None, fill_value: ArrayLike = 0) -> Array: ...
|
||||
def conj(self) -> Array: ...
|
||||
def conjugate(self) -> Array: ...
|
||||
def copy(self) -> Array: ...
|
||||
def cumprod(self, axis: int | Sequence[int] | None = None,
|
||||
dtype=None, out=None) -> Array: ...
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
|
||||
def cumsum(self, axis: int | Sequence[int] | None = None,
|
||||
dtype=None, out=None) -> Array: ...
|
||||
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
|
||||
def dot(self, b, *, precision=None) -> Array: ...
|
||||
def flatten(self) -> Array: ...
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
|
||||
def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ...
|
||||
def dot(self, b: ArrayLike, *, precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None) -> Array: ...
|
||||
def flatten(self, order: str = "C") -> Array: ...
|
||||
@property
|
||||
def imag(self) -> Array: ...
|
||||
def item(self, *args) -> Any: ...
|
||||
def max(self, axis: int | Sequence[int] | None = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Array: ...
|
||||
def mean(self, axis: int | Sequence[int] | None = None, dtype=None,
|
||||
out=None, keepdims=False, *, where=None,) -> Array: ...
|
||||
def min(self, axis: int | Sequence[int] | None = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Array: ...
|
||||
def item(self, *args: int) -> Any: ...
|
||||
def max(self, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array: ...
|
||||
def mean(self, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array: ...
|
||||
def min(self, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array: ...
|
||||
@property
|
||||
def nbytes(self) -> int: ...
|
||||
def nonzero(self, *, size=None, fill_value=None) -> Array: ...
|
||||
def prod(self, axis: int | Sequence[int] | None = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Array: ...
|
||||
def ptp(self, axis: int | Sequence[int] | None = None, out=None,
|
||||
keepdims=False,) -> Array: ...
|
||||
def ravel(self, order='C') -> Array: ...
|
||||
def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None,
|
||||
size: int | None = None,) -> tuple[Array, ...]: ...
|
||||
def prod(self, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None,
|
||||
promote_integers: bool = True) -> Array: ...
|
||||
def ptp(self, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False) -> Array: ...
|
||||
def ravel(self, order: str = 'C') -> Array: ...
|
||||
@property
|
||||
def real(self) -> Array: ...
|
||||
def repeat(self, repeats, axis: int | None = None, *,
|
||||
total_repeat_length=None) -> Array: ...
|
||||
def reshape(self, *args, order='C') -> Array: ...
|
||||
def round(self, decimals=0, out=None) -> Array: ...
|
||||
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
|
||||
def sort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def repeat(self, repeats: ArrayLike, axis: int | None = None, *,
|
||||
total_repeat_length: int | None = None) -> Array: ...
|
||||
def reshape(self, *args: Any, order: str = "C") -> Array: ...
|
||||
def round(self, decimals: int = 0, out: None = None) -> Array: ...
|
||||
def searchsorted(self, v: ArrayLike, side: str = 'left',
|
||||
sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ...
|
||||
def sort(self, axis: int | None = -1, *, kind: None = None,
|
||||
order: None = None, stable: bool = True, descending: bool = False) -> Array: ...
|
||||
def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ...
|
||||
def std(self, axis: int | Sequence[int] | None = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
||||
def sum(self, axis: int | Sequence[int] | None = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Array: ...
|
||||
def std(self, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
|
||||
def sum(self, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ...
|
||||
def swapaxes(self, axis1: int, axis2: int) -> Array: ...
|
||||
def take(self, indices, axis: int | None = None, out=None,
|
||||
mode=None) -> Array: ...
|
||||
def tobytes(self, order='C') -> bytes: ...
|
||||
def take(self, indices: ArrayLike, axis: int | None = None, out: None = None,
|
||||
mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False,
|
||||
fill_value: StaticScalar | None = None) -> Array: ...
|
||||
def tobytes(self, order: str = 'C') -> bytes: ...
|
||||
def tolist(self) -> list[Any]: ...
|
||||
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
|
||||
out=None) -> Array: ...
|
||||
def transpose(self, *args) -> Array: ...
|
||||
def trace(self, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1,
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
|
||||
def transpose(self, *args: Any) -> Array: ...
|
||||
@property
|
||||
def T(self) -> Array: ...
|
||||
@property
|
||||
def mT(self) -> Array: ...
|
||||
def var(self, axis: int | Sequence[int] | None = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
||||
def var(self, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
|
||||
def view(self, dtype=None, type=None) -> Array: ...
|
||||
|
||||
# Even though we don't always support the NumPy array protocol, e.g., for
|
||||
|
@ -58,7 +58,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
# functions, which can themselves handle instances from any of these classes.
|
||||
|
||||
|
||||
def _all(self: ArrayLike, axis: reductions.Axis = None, out: None = None,
|
||||
def _all(self: Array, axis: reductions.Axis = None, out: None = None,
|
||||
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
|
||||
"""Test whether all array elements along a given axis evaluate to True.
|
||||
|
||||
@ -107,7 +107,8 @@ def _argsort(self: Array, axis: int | None = -1, *, kind: None = None, order: No
|
||||
return lax_numpy.argsort(self, axis=axis, kind=kind, order=order,
|
||||
stable=stable, descending=descending)
|
||||
|
||||
def _astype(self: Array, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array:
|
||||
def _astype(self: Array, dtype: DTypeLike | None, copy: bool = False,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Copy the array and cast to a specified dtype.
|
||||
|
||||
This is implemented via :func:`jax.lax.convert_element_type`, which may
|
||||
@ -124,13 +125,12 @@ def _choose(self: Array, choices: Sequence[ArrayLike], out: None = None, mode: s
|
||||
"""
|
||||
return lax_numpy.choose(self, choices=choices)
|
||||
|
||||
def _clip(number: ArrayLike,
|
||||
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
|
||||
def _clip(self: Array, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
|
||||
"""Return an array whose values are limited to a specified range.
|
||||
|
||||
Refer to :func:`jax.numpy.clip` for full documentation.
|
||||
"""
|
||||
return lax_numpy.clip(number, min=min, max=max)
|
||||
return lax_numpy.clip(self, min=min, max=max)
|
||||
|
||||
def _compress(self: Array, condition: ArrayLike,
|
||||
axis: int | None = None, *, out: None = None,
|
||||
@ -163,7 +163,7 @@ def _copy(self: Array) -> Array:
|
||||
"""
|
||||
return lax_numpy.copy(self)
|
||||
|
||||
def _cumprod(self: Array, /, axis: int | Sequence[int] | None = None,
|
||||
def _cumprod(self: Array, axis: int | Sequence[int] | None = None,
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array:
|
||||
"""Return the cumulative product of the array.
|
||||
|
||||
@ -171,7 +171,7 @@ def _cumprod(self: Array, /, axis: int | Sequence[int] | None = None,
|
||||
"""
|
||||
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)
|
||||
|
||||
def _cumsum(self: Array, /, axis: int | Sequence[int] | None = None,
|
||||
def _cumsum(self: Array, axis: int | Sequence[int] | None = None,
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array:
|
||||
"""Return the cumulative sum of the array.
|
||||
|
||||
@ -258,9 +258,8 @@ def _nbytes_property(self: Array) -> int:
|
||||
"""Total bytes consumed by the elements of the array."""
|
||||
return np.size(self) * dtypes.dtype(self, canonicalize=True).itemsize
|
||||
|
||||
def _nonzero(self: Array, *, size: int | None = None,
|
||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||
) -> tuple[Array, ...]:
|
||||
def _nonzero(self: Array, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None,
|
||||
size: int | None = None) -> tuple[Array, ...]:
|
||||
"""Return indices of nonzero elements of an array.
|
||||
|
||||
Refer to :func:`jax.numpy.nonzero` for the full documentation.
|
||||
|
Loading…
x
Reference in New Issue
Block a user