mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

The `values` argument is passed to jnp.asarray(...), and some users currently pass not just ArrayLikes but also objects that can be coerced to arrays (e.g., tuples). Relax the type annotation to describe actual usage. We may wish to enforce that the arguments are arrays in the future, but if we do, we should first add a dynamic check/warning. PiperOrigin-RevId: 517451811
249 lines
10 KiB
Python
249 lines
10 KiB
Python
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import abc
|
|
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Set
|
|
import numpy as np
|
|
|
|
from jax._src.sharding import Sharding
|
|
from jax._src import lib
|
|
|
|
# Opaque aliases used in the annotations below; this stub types them as Any.
Shard = Any

Device = Any
# TODO: alias this to xla_client.Traceback
Traceback = Any
|
|
|
|
|
|
class Array(abc.ABC):
  """Type stub for the JAX array base class.

  Declares, for static type checkers, the array surface that JAX exposes:
  NumPy-style attributes and ``np.ndarray`` methods, Python operator
  overloads, JAX-specific extensions (``at``, ``weak_type``), and methods
  that exist on ``ArrayImpl`` but not on tracers. ``Array`` is an abstract
  base class; calling ``__init__`` directly raises ``TypeError`` directing
  users to ``jax.numpy.array`` or ``jax.numpy.zeros``.
  """

  dtype: np.dtype
  ndim: int
  size: int
  itemsize: int
  aval: Any

  @property
  def shape(self) -> Tuple[int, ...]: ...

  @property
  def sharding(self) -> Sharding: ...

  @property
  def addressable_shards(self) -> Sequence[Shard]: ...

  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

  def __getitem__(self, key, indices_are_sorted=False,
                  unique_indices=False) -> Array: ...
  def __setitem__(self, key, value) -> None: ...
  def __len__(self) -> int: ...
  def __iter__(self) -> Any: ...
  def __reversed__(self) -> Any: ...
  def __round__(self, ndigits=None) -> Array: ...

  # Comparisons

  # These return bool for object, so ignore override errors.
  def __lt__(self, other) -> Array: ...  # type: ignore[override]
  def __le__(self, other) -> Array: ...  # type: ignore[override]
  def __eq__(self, other) -> Array: ...  # type: ignore[override]
  def __ne__(self, other) -> Array: ...  # type: ignore[override]
  def __gt__(self, other) -> Array: ...  # type: ignore[override]
  def __ge__(self, other) -> Array: ...  # type: ignore[override]

  # Unary arithmetic

  def __neg__(self) -> Array: ...
  def __pos__(self) -> Array: ...
  def __abs__(self) -> Array: ...
  def __invert__(self) -> Array: ...

  # Binary arithmetic

  def __add__(self, other) -> Array: ...
  def __sub__(self, other) -> Array: ...
  def __mul__(self, other) -> Array: ...
  def __matmul__(self, other) -> Array: ...
  def __truediv__(self, other) -> Array: ...
  def __floordiv__(self, other) -> Array: ...
  def __mod__(self, other) -> Array: ...
  def __divmod__(self, other) -> Array: ...
  def __pow__(self, other) -> Array: ...
  def __lshift__(self, other) -> Array: ...
  def __rshift__(self, other) -> Array: ...
  def __and__(self, other) -> Array: ...
  def __xor__(self, other) -> Array: ...
  def __or__(self, other) -> Array: ...

  def __radd__(self, other) -> Array: ...
  def __rsub__(self, other) -> Array: ...
  def __rmul__(self, other) -> Array: ...
  def __rmatmul__(self, other) -> Array: ...
  def __rtruediv__(self, other) -> Array: ...
  def __rfloordiv__(self, other) -> Array: ...
  def __rmod__(self, other) -> Array: ...
  def __rdivmod__(self, other) -> Array: ...
  def __rpow__(self, other) -> Array: ...
  def __rlshift__(self, other) -> Array: ...
  def __rrshift__(self, other) -> Array: ...
  def __rand__(self, other) -> Array: ...
  def __rxor__(self, other) -> Array: ...
  def __ror__(self, other) -> Array: ...

  def __bool__(self) -> bool: ...
  def __complex__(self) -> complex: ...
  def __int__(self) -> int: ...
  def __float__(self) -> float: ...
  def __index__(self) -> int: ...

  # np.ndarray methods:
  def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Array: ...
  def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Array: ...
  def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
  def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
  def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
  def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
  def astype(self, dtype) -> Array: ...
  def choose(self, choices, out=None, mode='raise') -> Array: ...
  def clip(self, min=None, max=None, out=None) -> Array: ...
  def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
  def conj(self) -> Array: ...
  def conjugate(self) -> Array: ...
  def copy(self) -> Array: ...
  def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              dtype=None, out=None) -> Array: ...
  def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             dtype=None, out=None) -> Array: ...
  def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
  def dot(self, b, *, precision=None) -> Array: ...
  def flatten(self) -> Array: ...
  @property
  def imag(self) -> Array: ...
  def item(self, *args) -> Any: ...
  def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Array: ...
  def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=False, *, where=None) -> Array: ...
  def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Array: ...
  @property
  def nbytes(self) -> int: ...
  def nonzero(self, *, size=None, fill_value=None) -> Array: ...
  def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=None, initial=None, where=None) -> Array: ...
  def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=False) -> Array: ...
  def ravel(self, order='C') -> Array: ...
  @property
  def real(self) -> Array: ...
  def repeat(self, repeats, axis: Optional[int] = None, *,
             total_repeat_length=None) -> Array: ...
  def reshape(self, *args, order='C') -> Array: ...
  def round(self, decimals=0, out=None) -> Array: ...
  def searchsorted(self, v, side='left', sorter=None) -> Array: ...
  def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
  def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
  def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
  def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=None, initial=None, where=None) -> Array: ...
  def swapaxes(self, axis1: int, axis2: int) -> Array: ...
  def take(self, indices, axis: Optional[int] = None, out=None,
           mode=None) -> Array: ...
  def tobytes(self, order='C') -> bytes: ...
  def tolist(self) -> List[Any]: ...
  def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
            out=None) -> Array: ...
  def transpose(self, *args) -> Array: ...
  @property
  def T(self) -> Array: ...
  def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
  def view(self, dtype=None, type=None) -> Array: ...

  # Even though we don't always support the NumPy array protocol (e.g., for
  # tracer types), for type-checking purposes we must declare support so that
  # Array satisfies the NumPy ArrayLike protocol.
  # NumPy passes the requested dtype positionally (``x.__array__(dtype)``)
  # when one is given, e.g. ``np.asarray(x, dtype=...)``; declaring it as an
  # optional parameter keeps the no-argument protocol form valid as well.
  def __array__(self, dtype=None) -> np.ndarray: ...
  def __dlpack__(self) -> Any: ...

  # JAX extensions
  @property
  def at(self) -> _IndexUpdateHelper: ...
  @property
  def weak_type(self) -> bool: ...

  # Methods defined on ArrayImpl, but not on Tracers
  def addressable_data(self, index: int) -> Array: ...
  def block_until_ready(self) -> Array: ...
  def copy_to_host_async(self) -> None: ...
  def delete(self) -> None: ...
  def device(self) -> Device: ...
  def devices(self) -> Set[Device]: ...
  @property
  def global_shards(self) -> Sequence[Shard]: ...
  def is_deleted(self) -> bool: ...
  @property
  def is_fully_addressable(self) -> bool: ...
  @property
  def is_fully_replicated(self) -> bool: ...
  def on_device_size_in_bytes(self) -> int: ...
  @property
  def traceback(self) -> Traceback: ...
  def unsafe_buffer_pointer(self) -> int: ...
  @property
  def device_buffers(self) -> Any: ...
|
|
|
|
|
|
# Annotation alias grouping JAX arrays, NumPy arrays and scalars, and the
# builtin Python scalar types under a single name.
ArrayLike = Union[
  # JAX array type
  Array,
  # NumPy array type
  np.ndarray,
  # NumPy scalar types
  np.bool_,
  np.number,
  # Python scalar types
  bool,
  int,
  float,
  complex,
]
|
|
|
|
|
|
# TODO: restructure to avoid re-defining this here?
|
|
# from jax._src.numpy.lax_numpy import _IndexUpdateHelper
|
|
|
|
class _IndexUpdateHelper:
  """Stub for the helper object returned by ``Array.at``.

  Subscripting it produces an ``_IndexUpdateRef`` for the given index.
  """

  def __getitem__(
      self,
      index: Any,
  ) -> _IndexUpdateRef: ...
|
|
|
|
class _IndexUpdateRef:
  """Stub for the object obtained by subscripting ``Array.at``.

  Every method is annotated to return an ``Array``. The ``values`` argument
  of the update methods is typed ``Any`` rather than ``ArrayLike`` because it
  is forwarded to ``jnp.asarray(...)``, and callers pass objects that merely
  coerce to arrays (e.g. tuples).
  """

  # --- Reading -------------------------------------------------------------

  def get(self,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: Optional[str] = None,
          fill_value: Optional[ArrayLike] = None) -> Array: ...

  # --- Writing -------------------------------------------------------------

  def set(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: Optional[str] = None,
          fill_value: Optional[ArrayLike] = None) -> Array: ...

  def apply(self,
            func: Callable[[ArrayLike], ArrayLike],
            indices_are_sorted: bool = False,
            unique_indices: bool = False,
            mode: Optional[str] = None) -> Array: ...

  # --- Updates taking ``values`` (``mul`` and ``multiply`` share a
  # signature) ------------------------------------------------------------

  def add(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: Optional[str] = None) -> Array: ...

  def mul(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: Optional[str] = None) -> Array: ...

  def multiply(self,
               values: Any,
               indices_are_sorted: bool = False,
               unique_indices: bool = False,
               mode: Optional[str] = None) -> Array: ...

  def divide(self,
             values: Any,
             indices_are_sorted: bool = False,
             unique_indices: bool = False,
             mode: Optional[str] = None) -> Array: ...

  def power(self,
            values: Any,
            indices_are_sorted: bool = False,
            unique_indices: bool = False,
            mode: Optional[str] = None) -> Array: ...

  def min(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: Optional[str] = None) -> Array: ...

  def max(self,
          values: Any,
          indices_are_sorted: bool = False,
          unique_indices: bool = False,
          mode: Optional[str] = None) -> Array: ...
|