rocm_jax/jax/_src/basearray.pyi
Peter Hawkins 706549a270 Relax type of the values argument to .at[...].set(...) and friends.
The `values` argument is passed to jnp.asarray(...), and some users currently pass not just ArrayLikes but also other values that can be coerced to arrays (e.g., tuples). Relax the type annotation to better describe the world as it is.

We may wish to enforce that the arguments are arrays in the future, but we should add a dynamic check/warning in that case.

PiperOrigin-RevId: 517451811
2023-03-17 10:46:36 -07:00

249 lines
10 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Set
import numpy as np
from jax._src.sharding import Sharding
from jax._src import lib
# Placeholder aliases for types this stub does not spell out; all are Any.
Shard = Any
Device = Any
# TODO: alias this to xla_client.Traceback
Traceback = Any
class Array(abc.ABC):
  # Type-checker description of JAX arrays. Method bodies are elided with
  # ``...``; only ``__init__`` has a real body, and it always raises.

  dtype: np.dtype
  ndim: int
  size: int
  itemsize: int
  aval: Any

  @property
  def shape(self) -> Tuple[int, ...]: ...
  @property
  def sharding(self) -> Sharding: ...
  @property
  def addressable_shards(self) -> Sequence[Shard]: ...

  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    # Arrays are never built through this constructor; the message directs
    # users to the jax.numpy factory functions instead.
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

  def __getitem__(self, key, indices_are_sorted=False,
                  unique_indices=False) -> Array: ...
  def __setitem__(self, key, value) -> None: ...
  def __len__(self) -> int: ...
  def __iter__(self) -> Any: ...
  def __reversed__(self) -> Any: ...
  def __round__(self, ndigits=None) -> Array: ...

  # Comparisons
  # ``object`` declares these as returning bool; here they return Array, so
  # the resulting override errors are ignored.
  def __lt__(self, other) -> Array: ...  # type: ignore[override]
  def __le__(self, other) -> Array: ...  # type: ignore[override]
  def __eq__(self, other) -> Array: ...  # type: ignore[override]
  def __ne__(self, other) -> Array: ...  # type: ignore[override]
  def __gt__(self, other) -> Array: ...  # type: ignore[override]
  def __ge__(self, other) -> Array: ...  # type: ignore[override]

  # Unary arithmetic
  def __neg__(self) -> Array: ...
  def __pos__(self) -> Array: ...
  def __abs__(self) -> Array: ...
  def __invert__(self) -> Array: ...

  # Binary arithmetic
  def __add__(self, other) -> Array: ...
  def __sub__(self, other) -> Array: ...
  def __mul__(self, other) -> Array: ...
  def __matmul__(self, other) -> Array: ...
  def __truediv__(self, other) -> Array: ...
  def __floordiv__(self, other) -> Array: ...
  def __mod__(self, other) -> Array: ...
  # divmod() produces a (quotient, remainder) pair rather than a single array.
  def __divmod__(self, other) -> Tuple[Array, Array]: ...
  def __pow__(self, other) -> Array: ...
  def __lshift__(self, other) -> Array: ...
  def __rshift__(self, other) -> Array: ...
  def __and__(self, other) -> Array: ...
  def __xor__(self, other) -> Array: ...
  def __or__(self, other) -> Array: ...

  def __radd__(self, other) -> Array: ...
  def __rsub__(self, other) -> Array: ...
  def __rmul__(self, other) -> Array: ...
  def __rmatmul__(self, other) -> Array: ...
  def __rtruediv__(self, other) -> Array: ...
  def __rfloordiv__(self, other) -> Array: ...
  def __rmod__(self, other) -> Array: ...
  # Reflected divmod() likewise produces a (quotient, remainder) pair.
  def __rdivmod__(self, other) -> Tuple[Array, Array]: ...
  def __rpow__(self, other) -> Array: ...
  def __rlshift__(self, other) -> Array: ...
  def __rrshift__(self, other) -> Array: ...
  def __rand__(self, other) -> Array: ...
  def __rxor__(self, other) -> Array: ...
  def __ror__(self, other) -> Array: ...

  # Conversions to Python scalars.
  def __bool__(self) -> bool: ...
  def __complex__(self) -> complex: ...
  def __int__(self) -> int: ...
  def __float__(self) -> float: ...
  def __index__(self) -> int: ...

  # np.ndarray methods:
  def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Array: ...
  def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Array: ...
  def argmax(self, axis: Optional[int] = None, out=None,
             keepdims=None) -> Array: ...
  def argmin(self, axis: Optional[int] = None, out=None,
             keepdims=None) -> Array: ...
  def argpartition(self, kth, axis=-1, kind='introselect',
                   order=None) -> Array: ...
  def argsort(self, axis: Optional[int] = -1, kind='quicksort',
              order=None) -> Array: ...
  def astype(self, dtype) -> Array: ...
  def choose(self, choices, out=None, mode='raise') -> Array: ...
  def clip(self, min=None, max=None, out=None) -> Array: ...
  def compress(self, condition, axis: Optional[int] = None,
               out=None) -> Array: ...
  def conj(self) -> Array: ...
  def conjugate(self) -> Array: ...
  def copy(self) -> Array: ...
  def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              dtype=None, out=None) -> Array: ...
  def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             dtype=None, out=None) -> Array: ...
  def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
  def dot(self, b, *, precision=None) -> Array: ...
  def flatten(self) -> Array: ...
  @property
  def imag(self) -> Array: ...
  def item(self, *args) -> Any: ...
  def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Array: ...
  def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None, out=None, keepdims=False, *, where=None) -> Array: ...
  def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Array: ...
  @property
  def nbytes(self) -> int: ...
  # Like numpy.ndarray.nonzero, this yields one index array per dimension.
  def nonzero(self, *, size=None, fill_value=None) -> Tuple[Array, ...]: ...
  def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None, out=None, keepdims=None, initial=None,
           where=None) -> Array: ...
  def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=False) -> Array: ...
  def ravel(self, order='C') -> Array: ...
  @property
  def real(self) -> Array: ...
  def repeat(self, repeats, axis: Optional[int] = None, *,
             total_repeat_length=None) -> Array: ...
  def reshape(self, *args, order='C') -> Array: ...
  def round(self, decimals=0, out=None) -> Array: ...
  def searchsorted(self, v, side='left', sorter=None) -> Array: ...
  def sort(self, axis: Optional[int] = -1, kind='quicksort',
           order=None) -> Array: ...
  def squeeze(self,
              axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
  def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *,
          where=None) -> Array: ...
  def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, keepdims=None, initial=None,
          where=None) -> Array: ...
  def swapaxes(self, axis1: int, axis2: int) -> Array: ...
  def take(self, indices, axis: Optional[int] = None, out=None,
           mode=None) -> Array: ...
  def tobytes(self, order='C') -> bytes: ...
  def tolist(self) -> List[Any]: ...
  def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
            out=None) -> Array: ...
  def transpose(self, *args) -> Array: ...
  @property
  def T(self) -> Array: ...
  def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *,
          where=None) -> Array: ...
  def view(self, dtype=None, type=None) -> Array: ...

  # Tracer types do not always support the NumPy array protocol. It is still
  # declared here so that, for type-checking purposes, Array satisfies NumPy's
  # ArrayLike protocol. ``dtype`` mirrors NumPy's optional argument to
  # ``__array__``.
  def __array__(self, dtype=None) -> np.ndarray: ...
  def __dlpack__(self) -> Any: ...

  # JAX extensions
  @property
  def at(self) -> _IndexUpdateHelper: ...
  @property
  def weak_type(self) -> bool: ...

  # Methods defined on ArrayImpl, but not on Tracers
  def addressable_data(self, index: int) -> Array: ...
  def block_until_ready(self) -> Array: ...
  def copy_to_host_async(self) -> None: ...
  def delete(self) -> None: ...
  def device(self) -> Device: ...
  def devices(self) -> Set[Device]: ...
  @property
  def global_shards(self) -> Sequence[Shard]: ...
  def is_deleted(self) -> bool: ...
  @property
  def is_fully_addressable(self) -> bool: ...
  @property
  def is_fully_replicated(self) -> bool: ...
  def on_device_size_in_bytes(self) -> int: ...
  @property
  def traceback(self) -> Traceback: ...
  def unsafe_buffer_pointer(self) -> int: ...
  @property
  def device_buffers(self) -> Any: ...
# Everything accepted where an "array-like" input is expected: JAX arrays,
# NumPy arrays, NumPy scalars and Python scalars.
ArrayLike = Union[
    Array,       # JAX array type
    np.ndarray,  # NumPy array type
    np.bool_,    # NumPy boolean scalar type
    np.number,   # NumPy numeric scalar types
    bool,        # Python scalar types ...
    int,
    float,
    complex,     # ... end of Python scalar types
]
# TODO: restructure to avoid re-defining this here?
# from jax._src.numpy.lax_numpy import _IndexUpdateHelper
class _IndexUpdateHelper:
def __getitem__(self, index: Any) -> _IndexUpdateRef: ...
class _IndexUpdateRef:
def get(self, indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ...
def set(self, values: Any,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ...
def add(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def mul(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def multiply(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def divide(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def power(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def min(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def max(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...