rocm_jax/jax/_src/basearray.pyi
2025-01-16 18:56:52 -08:00

305 lines
13 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import Any, Protocol, Union, runtime_checkable
import numpy as np
from jax._src.sharding import Sharding
from jax._src.partition_spec import PartitionSpec
# TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py.
# We redefine these here to prevent circular imports.
@runtime_checkable
class SupportsDType(Protocol):
  """Structural type matching any object that exposes a ``dtype`` attribute.

  Being ``runtime_checkable``, it can be used with ``isinstance``; that check
  only tests for the presence of ``dtype``, not its type.
  """

  @property
  def dtype(self) -> np.dtype: ...
# Anything accepted where a dtype is expected in this stub.
DTypeLike = Union[
    str,
    type[Any],
    np.dtype,
    SupportsDType,  # any object exposing a ``dtype`` attribute
]
# An axis argument: a single int, a sequence of ints, or None.
Axis = Union[int, Sequence[int], None]

# Opaque placeholders, deliberately typed as ``Any`` in this stub.
Shard = Any
Device = Any
# TODO: alias this to xla_client.Traceback
Traceback = Any
# TODO(jakevdp): fix import cycles and import this from jax._src.lax.
PrecisionLike = Any
class Array(abc.ABC):
  """Static type stub for JAX arrays.

  Declares the attributes, operators and methods that type checkers should
  see on a JAX array. Bodies are ``...``; the class is not meant to be
  instantiated directly (``__init__`` always raises ``TypeError``).
  """

  # Abstract value associated with the array; typed as ``Any`` in this stub.
  aval: Any

  @property
  def dtype(self) -> np.dtype: ...
  @property
  def ndim(self) -> int: ...
  @property
  def size(self) -> int: ...
  @property
  def itemsize(self) -> int: ...
  @property
  def shape(self) -> tuple[int, ...]: ...

  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    # Direct construction is always rejected; the message points users at the
    # jax.numpy constructors instead.
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

  def __array_namespace__(self, *, api_version: None | str = ...) -> ModuleType: ...
  def __getitem__(self, key) -> Array: ...
  def __setitem__(self, key, value) -> None: ...
  def __len__(self) -> int: ...
  def __iter__(self) -> Any: ...
  def __reversed__(self) -> Any: ...
  def __round__(self, ndigits=None) -> Array: ...

  # Comparisons

  # these return bool for object, so ignore override errors.
  def __lt__(self, other) -> Array: ...
  def __le__(self, other) -> Array: ...
  def __eq__(self, other) -> Array: ...  # type: ignore[override]
  def __ne__(self, other) -> Array: ...  # type: ignore[override]
  def __gt__(self, other) -> Array: ...
  def __ge__(self, other) -> Array: ...

  # Unary arithmetic

  def __neg__(self) -> Array: ...
  def __pos__(self) -> Array: ...
  def __abs__(self) -> Array: ...
  def __invert__(self) -> Array: ...

  # Binary arithmetic

  def __add__(self, other) -> Array: ...
  def __sub__(self, other) -> Array: ...
  def __mul__(self, other) -> Array: ...
  def __matmul__(self, other) -> Array: ...
  def __truediv__(self, other) -> Array: ...
  def __floordiv__(self, other) -> Array: ...
  def __mod__(self, other) -> Array: ...
  def __divmod__(self, other) -> tuple[Array, Array]: ...
  def __pow__(self, other) -> Array: ...
  def __lshift__(self, other) -> Array: ...
  def __rshift__(self, other) -> Array: ...
  def __and__(self, other) -> Array: ...
  def __xor__(self, other) -> Array: ...
  def __or__(self, other) -> Array: ...

  def __radd__(self, other) -> Array: ...
  def __rsub__(self, other) -> Array: ...
  def __rmul__(self, other) -> Array: ...
  def __rmatmul__(self, other) -> Array: ...
  def __rtruediv__(self, other) -> Array: ...
  def __rfloordiv__(self, other) -> Array: ...
  def __rmod__(self, other) -> Array: ...
  # divmod(other, self) dispatches here and produces a (quotient, remainder)
  # pair, so the return type must match __divmod__ above.
  def __rdivmod__(self, other) -> tuple[Array, Array]: ...
  def __rpow__(self, other) -> Array: ...
  def __rlshift__(self, other) -> Array: ...
  def __rrshift__(self, other) -> Array: ...
  def __rand__(self, other) -> Array: ...
  def __rxor__(self, other) -> Array: ...
  def __ror__(self, other) -> Array: ...

  def __bool__(self) -> bool: ...
  def __complex__(self) -> complex: ...
  def __int__(self) -> int: ...
  def __float__(self) -> float: ...
  def __index__(self) -> int: ...

  def __buffer__(self, flags: int) -> memoryview: ...
  def __release_buffer__(self, view: memoryview) -> None: ...

  # np.ndarray methods:
  def all(self, axis: Axis = None, out: None = None,
          keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
  def any(self, axis: Axis = None, out: None = None,
          keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
  def argmax(self, axis: int | None = None, out: None = None,
             keepdims: bool | None = None) -> Array: ...
  def argmin(self, axis: int | None = None, out: None = None,
             keepdims: bool | None = None) -> Array: ...
  def argpartition(self, kth: int, axis: int = -1) -> Array: ...
  def argsort(self, axis: int | None = -1, *, kind: None = None, order: None = None,
              stable: bool = True, descending: bool = False) -> Array: ...
  def astype(self, dtype: DTypeLike | None = None, copy: bool = False,
             device: Device | Sharding | None = None) -> Array: ...
  def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ...
  def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ...
  def compress(self, condition: ArrayLike,
               axis: int | None = None, *, out: None = None,
               size: int | None = None, fill_value: ArrayLike = 0) -> Array: ...
  def conj(self) -> Array: ...
  def conjugate(self) -> Array: ...
  def copy(self) -> Array: ...
  def cumprod(self, axis: Axis = None, dtype: DTypeLike | None = None,
              out: None = None) -> Array: ...
  def cumsum(self, axis: Axis = None, dtype: DTypeLike | None = None,
             out: None = None) -> Array: ...
  def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ...
  def dot(self, b: ArrayLike, *, precision: PrecisionLike = None,
          preferred_element_type: DTypeLike | None = None) -> Array: ...
  def flatten(self, order: str = "C") -> Array: ...
  @property
  def imag(self) -> Array: ...
  def item(self, *args: int) -> Any: ...
  def max(self, axis: Axis = None, out: None = None,
          keepdims: bool = False, initial: ArrayLike | None = None,
          where: ArrayLike | None = None) -> Array: ...
  def mean(self, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, keepdims: bool = False, *,
           where: ArrayLike | None = None) -> Array: ...
  def min(self, axis: Axis = None, out: None = None,
          keepdims: bool = False, initial: ArrayLike | None = None,
          where: ArrayLike | None = None) -> Array: ...
  @property
  def nbytes(self) -> int: ...
  def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None,
              size: int | None = None) -> tuple[Array, ...]: ...
  def prod(self, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, keepdims: bool = False,
           initial: ArrayLike | None = None, where: ArrayLike | None = None,
           promote_integers: bool = True) -> Array: ...
  def ptp(self, axis: Axis = None, out: None = None,
          keepdims: bool = False) -> Array: ...
  def ravel(self, order: str = 'C') -> Array: ...
  @property
  def real(self) -> Array: ...
  def repeat(self, repeats: ArrayLike, axis: int | None = None, *,
             total_repeat_length: int | None = None) -> Array: ...
  def reshape(self, *args: Any, order: str = "C") -> Array: ...
  def round(self, decimals: int = 0, out: None = None) -> Array: ...
  def searchsorted(self, v: ArrayLike, side: str = 'left',
                   sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ...
  def sort(self, axis: int | None = -1, *, kind: None = None,
           order: None = None, stable: bool = True, descending: bool = False) -> Array: ...
  def squeeze(self, axis: Axis = None) -> Array: ...
  def std(self, axis: Axis = None, dtype: DTypeLike | None = None,
          out: None = None, ddof: int = 0, keepdims: bool = False, *,
          where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
  def sum(self, axis: Axis = None, dtype: DTypeLike | None = None,
          out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
          where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ...
  def swapaxes(self, axis1: int, axis2: int) -> Array: ...
  def take(self, indices: ArrayLike, axis: int | None = None, out: None = None,
           mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False,
           fill_value: StaticScalar | None = None) -> Array: ...
  def tobytes(self, order: str = 'C') -> bytes: ...
  def tolist(self) -> list[Any]: ...
  def trace(self, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1,
            dtype: DTypeLike | None = None, out: None = None) -> Array: ...
  def transpose(self, *args: Any) -> Array: ...
  @property
  def T(self) -> Array: ...
  @property
  def mT(self) -> Array: ...
  def var(self, axis: Axis = None, dtype: DTypeLike | None = None,
          out: None = None, ddof: int = 0, keepdims: bool = False, *,
          where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
  def view(self, dtype: DTypeLike | None = None, type: None = None) -> Array: ...

  # Even though we don't always support the NumPy array protocol, e.g., for
  # tracer types, for type checking purposes we must declare support so we
  # implement the NumPy ArrayLike protocol.
  def __array__(self, dtype: np.dtype | None = ...,
                copy: bool | None = ...) -> np.ndarray: ...
  def __dlpack__(self) -> Any: ...

  # JAX extensions
  @property
  def at(self) -> _IndexUpdateHelper: ...
  @property
  def weak_type(self) -> bool: ...

  # Methods defined on ArrayImpl, but not on Tracers
  def addressable_data(self, index: int) -> Array: ...
  def block_until_ready(self) -> Array: ...
  def copy_to_host_async(self) -> None: ...
  def delete(self) -> None: ...
  def devices(self) -> set[Device]: ...
  @property
  def sharding(self) -> Sharding: ...
  @property
  def committed(self) -> bool: ...
  @property
  def device(self) -> Device | Sharding: ...
  @property
  def addressable_shards(self) -> Sequence[Shard]: ...
  @property
  def global_shards(self) -> Sequence[Shard]: ...
  def is_deleted(self) -> bool: ...
  @property
  def is_fully_addressable(self) -> bool: ...
  @property
  def is_fully_replicated(self) -> bool: ...
  def on_device_size_in_bytes(self) -> int: ...
  @property
  def traceback(self) -> Traceback: ...
  def unsafe_buffer_pointer(self) -> int: ...
  def to_device(self, device: Device | Sharding, *,
                stream: int | Any | None = ...) -> Array: ...
# Scalar values accepted as array-like: NumPy scalar types first, then the
# built-in Python scalar types.
StaticScalar = Union[np.bool_, np.number, bool, int, float, complex]
# Values accepted as array-like inputs: a JAX Array, a NumPy ndarray, or any
# StaticScalar (whose members are flattened into this Union).
ArrayLike = Union[Array, np.ndarray, StaticScalar]
# TODO: restructure to avoid re-defining this here?
# from jax._src.numpy.lax_numpy import _IndexUpdateHelper
class _IndexUpdateHelper:
  """Stub for the object returned by the ``Array.at`` property.

  Subscripting it yields an ``_IndexUpdateRef``.
  """

  def __getitem__(self, index: Any) -> _IndexUpdateRef: ...
class _IndexUpdateRef:
  """Stub for the object produced by ``x.at[index]``.

  Every method here is typed to return an ``Array``. Most share the options
  ``indices_are_sorted``, ``unique_indices`` and ``mode``.

  NOTE(review): only ``subtract`` places its options after a bare ``*``
  (keyword-only); the other methods accept them positionally. Confirm which
  form matches the runtime implementation before relying on positional use.
  """

  def get(self,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: str | None = None,
          fill_value: StaticScalar | None = None,
          out_spec: Sharding | PartitionSpec | None = None) -> Array: ...

  def set(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: str | None = None,
          fill_value: StaticScalar | None = None) -> Array: ...

  def add(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: str | None = None) -> Array: ...

  def subtract(self,
               values: Any,
               *,
               indices_are_sorted: bool = False,
               unique_indices: bool = False,
               mode: str | None = None) -> Array: ...

  def mul(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: str | None = None) -> Array: ...

  def multiply(self,
               values: Any,
               indices_are_sorted: bool = False,
               unique_indices: bool = False,
               mode: str | None = None) -> Array: ...

  def divide(self,
             values: Any,
             indices_are_sorted: bool = False,
             unique_indices: bool = False,
             mode: str | None = None) -> Array: ...

  def power(self,
            values: Any,
            indices_are_sorted: bool = False,
            unique_indices: bool = False,
            mode: str | None = None) -> Array: ...

  def min(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: str | None = None) -> Array: ...

  def max(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: str | None = None) -> Array: ...

  def apply(self,
            func: Callable[[ArrayLike], ArrayLike],
            indices_are_sorted: bool = False,
            unique_indices: bool = False,
            mode: str | None = None) -> Array: ...