mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add initial jax.Array base class for instance checks & annotation
This commit is contained in:
parent
2180710c2a
commit
0cb233eec9
@ -18,6 +18,7 @@ repos:
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: (jax/|tests/typing_test\.py)
|
||||
exclude: jax/_src/basearray.py # Use pyi instead
|
||||
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]
|
||||
|
||||
- repo: https://github.com/mwouts/jupytext
|
||||
|
@ -15,6 +15,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
{jax-issue}`#7733`) is stable and public. See [the
|
||||
overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs
|
||||
for {mod}`jax.stages`.
|
||||
* Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks
|
||||
and type annotations for array types in JAX. Notice that this included some subtle
|
||||
changes to how `isinstance` works for {class}`jax.numpy.ndarray` for jax-internal
|
||||
objects, as {class}`jax.numpy.ndarray` is now a simple alias of {class}`jax.Array`.
|
||||
* Breaking changes
|
||||
* `jax._src` is no longer imported into the from the public `jax` namespace.
|
||||
This may break users that were using JAX internals.
|
||||
|
@ -35,6 +35,8 @@ del _cloud_tpu_init
|
||||
from jax import config as _config_module
|
||||
del _config_module
|
||||
|
||||
from jax._src.basearray import Array as Array
|
||||
|
||||
from jax._src.config import (
|
||||
config as config,
|
||||
enable_checks as enable_checks,
|
||||
|
@ -612,8 +612,7 @@ def _cpp_jit(
|
||||
# to know whether `jax.jit(f)(x)` will execute or trace, it's not enough to
|
||||
# inspect the argument x, we actually do need to execute it and look at the
|
||||
# outputs that could be tracers (if f is capturing `Tracer` by closure).
|
||||
execute: Optional[functools.partial] = (
|
||||
dispatch.xla_callable.most_recent_entry())
|
||||
execute = dispatch.xla_callable.most_recent_entry()
|
||||
|
||||
fastpath_data = None
|
||||
|
||||
|
53
jax/_src/basearray.py
Normal file
53
jax/_src/basearray.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note that type annotations for this file are defined in basearray.pyi
|
||||
|
||||
import abc
|
||||
|
||||
class Array(abc.ABC):
|
||||
"""Experimental Array base class for JAX
|
||||
|
||||
`jax.Array` is meant as the future public interface for instance checks and
|
||||
type annotation of JAX array objects. JAX Array object types are currently in
|
||||
flux, and this class only fully supports the new `jax.experimental.Array`, which
|
||||
will soon replace the old-style {class}`DeviceArray`, {class}`ShardedDeviceArray`,
|
||||
{class}`GlobalDeviceArray`, etc.
|
||||
|
||||
The compatibility is summarized in the following table:
|
||||
|
||||
================================ ====================== =========================
|
||||
object type ``isinstance`` support type annotation support
|
||||
================================ ====================== =========================
|
||||
{class}`DeviceArray` ✅ ❌
|
||||
{class}`ShardedDeviceArray` ✅ ❌
|
||||
{class}`GlobalDeviceArray` ✅ ❌
|
||||
{class}`~jax.core.Tracer` ✅ ✅
|
||||
{class}`~jax.experimental.Array` ✅ ✅
|
||||
================================ ====================== =========================
|
||||
|
||||
In other words, ``isinstance(x, jax.Array)`` will return True for any of these types,
|
||||
whereas annotations such as ``x : jax.Array`` will only type-check correctly for
|
||||
instances of {class}`~jax.core.Tracer` and {class}`jax.experimental.Array`, and not
|
||||
for the other soon-to-be-deprecated array types.
|
||||
"""
|
||||
# Note: no abstract methods are defined in this base class; the associated pyi
|
||||
# file contains the type signature for static type checking.
|
||||
|
||||
__slots__ = ['__weakref__']
|
||||
|
||||
# at property must be defined because we overwrite its docstring in lax_numpy.py
|
||||
@property
|
||||
def at(self):
|
||||
raise NotImplementedError("property must be defined in subclasses")
|
172
jax/_src/basearray.pyi
Normal file
172
jax/_src/basearray.pyi
Normal file
@ -0,0 +1,172 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Array(abc.ABC):
|
||||
dtype: np.dtype
|
||||
ndim: int
|
||||
size: int
|
||||
aval: Any
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]: ...
|
||||
|
||||
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
|
||||
order=None):
|
||||
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
|
||||
" Use jax.numpy.array, or jax.numpy.zeros instead.")
|
||||
|
||||
def __getitem__(self, key, indices_are_sorted=False,
|
||||
unique_indices=False) -> Array: ...
|
||||
def __setitem__(self, key, value) -> None: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __iter__(self) -> Any: ...
|
||||
def __reversed__(self) -> Any: ...
|
||||
def __round__(self, ndigits=None) -> Array: ...
|
||||
|
||||
# Comparisons
|
||||
|
||||
# these return bool for object, so ignore override errors.
|
||||
def __lt__(self, other) -> Array: ... # type: ignore[override]
|
||||
def __le__(self, other) -> Array: ... # type: ignore[override]
|
||||
def __eq__(self, other) -> Array: ... # type: ignore[override]
|
||||
def __ne__(self, other) -> Array: ... # type: ignore[override]
|
||||
def __gt__(self, other) -> Array: ... # type: ignore[override]
|
||||
def __ge__(self, other) -> Array: ... # type: ignore[override]
|
||||
|
||||
# Unary arithmetic
|
||||
|
||||
def __neg__(self) -> Array: ...
|
||||
def __pos__(self) -> Array: ...
|
||||
def __abs__(self) -> Array: ...
|
||||
def __invert__(self) -> Array: ...
|
||||
|
||||
# Binary arithmetic
|
||||
|
||||
def __add__(self, other) -> Array: ...
|
||||
def __sub__(self, other) -> Array: ...
|
||||
def __mul__(self, other) -> Array: ...
|
||||
def __matmul__(self, other) -> Array: ...
|
||||
def __truediv__(self, other) -> Array: ...
|
||||
def __floordiv__(self, other) -> Array: ...
|
||||
def __mod__(self, other) -> Array: ...
|
||||
def __divmod__(self, other) -> Array: ...
|
||||
def __pow__(self, other) -> Array: ...
|
||||
def __lshift__(self, other) -> Array: ...
|
||||
def __rshift__(self, other) -> Array: ...
|
||||
def __and__(self, other) -> Array: ...
|
||||
def __xor__(self, other) -> Array: ...
|
||||
def __or__(self, other) -> Array: ...
|
||||
|
||||
def __radd__(self, other) -> Array: ...
|
||||
def __rsub__(self, other) -> Array: ...
|
||||
def __rmul__(self, other) -> Array: ...
|
||||
def __rmatmul__(self, other) -> Array: ...
|
||||
def __rtruediv__(self, other) -> Array: ...
|
||||
def __rfloordiv__(self, other) -> Array: ...
|
||||
def __rmod__(self, other) -> Array: ...
|
||||
def __rdivmod__(self, other) -> Array: ...
|
||||
def __rpow__(self, other) -> Array: ...
|
||||
def __rlshift__(self, other) -> Array: ...
|
||||
def __rrshift__(self, other) -> Array: ...
|
||||
def __rand__(self, other) -> Array: ...
|
||||
def __rxor__(self, other) -> Array: ...
|
||||
def __ror__(self, other) -> Array: ...
|
||||
|
||||
def __bool__(self) -> bool: ...
|
||||
def __complex__(self) -> complex: ...
|
||||
def __int__(self) -> int: ...
|
||||
def __float__(self) -> float: ...
|
||||
def __index__(self) -> int: ...
|
||||
|
||||
# np.ndarray methods:
|
||||
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None) -> Array: ...
|
||||
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None) -> Array: ...
|
||||
def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
|
||||
def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
|
||||
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
|
||||
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def astype(self, dtype) -> Array: ...
|
||||
def choose(self, choices, out=None, mode='raise') -> Array: ...
|
||||
def clip(self, min=None, max=None, out=None) -> Array: ...
|
||||
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
|
||||
def conj(self) -> Array: ...
|
||||
def conjugate(self) -> Array: ...
|
||||
def copy(self) -> Array: ...
|
||||
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None) -> Array: ...
|
||||
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None) -> Array: ...
|
||||
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
|
||||
def dot(self, b, *, precision=None) -> Array: ...
|
||||
def flatten(self) -> Array: ...
|
||||
@property
|
||||
def imag(self) -> Array: ...
|
||||
def item(self, *args) -> Any: ...
|
||||
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Array: ...
|
||||
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=False, *, where=None,) -> Array: ...
|
||||
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Array: ...
|
||||
@property
|
||||
def nbytes(self) -> int: ...
|
||||
def nonzero(self, *, size=None, fill_value=None) -> Array: ...
|
||||
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Array: ...
|
||||
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=False,) -> Array: ...
|
||||
def ravel(self, order='C') -> Array: ...
|
||||
@property
|
||||
def real(self) -> Array: ...
|
||||
def repeat(self, repeats, axis: Optional[int] = None, *,
|
||||
total_repeat_length=None) -> Array: ...
|
||||
def reshape(self, *args, order='C') -> Array: ...
|
||||
def round(self, decimals=0, out=None) -> Array: ...
|
||||
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
|
||||
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
|
||||
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
||||
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Array: ...
|
||||
def swapaxes(self, axis1: int, axis2: int) -> Array: ...
|
||||
def take(self, indices, axis: Optional[int] = None, out=None,
|
||||
mode=None) -> Array: ...
|
||||
def tobytes(self, order='C') -> bytes: ...
|
||||
def tolist(self) -> List[Any]: ...
|
||||
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
|
||||
out=None) -> Array: ...
|
||||
def transpose(self, *args) -> Array: ...
|
||||
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
||||
def view(self, dtype=None, type=None) -> Array: ...
|
||||
|
||||
# Even though we don't always support the NumPy array protocol, e.g., for
|
||||
# tracer types, for type checking purposes we must declare support so we
|
||||
# implement the NumPy ArrayLike protocol.
|
||||
def __array__(self) -> np.ndarray: ...
|
||||
def __dlpack__(self) -> Any: ...
|
||||
|
||||
# JAX extensions
|
||||
@property
|
||||
def at(self) -> Any: ...
|
||||
@property
|
||||
def weak_type(self) -> bool: ...
|
@ -37,6 +37,7 @@ from jax._src import source_info_util, traceback_util
|
||||
from jax._src.lax import control_flow as cf
|
||||
from jax._src.config import config
|
||||
from jax import lax
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
|
||||
safe_zip)
|
||||
|
||||
@ -62,9 +63,9 @@ def setnewattr(obj, name, val):
|
||||
|
||||
## Error value data type and functional assert.
|
||||
|
||||
Bool = Union[bool, core.Tracer]
|
||||
Int = Union[int, core.Tracer]
|
||||
Payload = Union[np.ndarray, jnp.ndarray, core.Tracer]
|
||||
Bool = Union[bool, Array]
|
||||
Int = Union[int, Array]
|
||||
Payload = Union[np.ndarray, Array]
|
||||
|
||||
# For now, the payload needs to be a fixed-size array: 3 int32s, used for the
|
||||
# OOB message.
|
||||
|
@ -30,6 +30,7 @@ from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
from jax._src.lib import xla_client as xc
|
||||
import jax._src.util as util
|
||||
from jax._src.typing import Array
|
||||
|
||||
### device-persistent data
|
||||
|
||||
@ -332,7 +333,9 @@ class DeletedBuffer(object): pass
|
||||
deleted_buffer = DeletedBuffer()
|
||||
|
||||
|
||||
Array.register(DeviceArray)
|
||||
device_array_types: List[type] = [xc.Buffer, _DeviceArray]
|
||||
for _device_array in device_array_types:
|
||||
core.literalable_types.add(_device_array)
|
||||
core.pytype_aval_mappings[_device_array] = abstract_arrays.canonical_concrete_aval
|
||||
Array.register(_device_array)
|
||||
|
@ -3601,12 +3601,12 @@ def take_along_axis(arr, indices, axis: Optional[int],
|
||||
j += 1
|
||||
|
||||
|
||||
gather_indices = lax.concatenate(gather_indices, dimension=j)
|
||||
gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
|
||||
dnums = lax.GatherDimensionNumbers(
|
||||
offset_dims=tuple(offset_dims),
|
||||
collapsed_slice_dims=tuple(collapsed_slice_dims),
|
||||
start_index_map=tuple(start_index_map))
|
||||
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes),
|
||||
return lax.gather(arr, gather_indices_arr, dnums, tuple(slice_sizes),
|
||||
mode="fill" if mode is None else mode)
|
||||
|
||||
### Indexing
|
||||
|
@ -12,282 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import Any, Tuple, Optional, Union
|
||||
__all__ = ['ndarray']
|
||||
|
||||
from jax import core
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import device_array
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ArrayMeta(abc.ABCMeta):
|
||||
"""Metaclass for overriding ndarray isinstance checks."""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
# Allow tracer instances with avals that are instances of UnshapedArray.
|
||||
# We could instead just declare Tracer an instance of the ndarray type, but
|
||||
# there can be traced values that are not arrays. The main downside here is
|
||||
# that isinstance(x, ndarray) might return true but
|
||||
# issubclass(type(x), ndarray) might return false for an array tracer.
|
||||
try:
|
||||
return isinstance(instance.aval, core.UnshapedArray)
|
||||
except AttributeError:
|
||||
super().__instancecheck__(instance)
|
||||
|
||||
|
||||
class ndarray(metaclass=ArrayMeta):
|
||||
dtype: np.dtype
|
||||
ndim: int
|
||||
shape: Tuple[int, ...]
|
||||
size: int
|
||||
|
||||
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
|
||||
order=None):
|
||||
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
|
||||
" Use jax.numpy.array, or jax.numpy.zeros instead.")
|
||||
|
||||
@abc.abstractmethod
|
||||
def __getitem__(self, key, indices_are_sorted=False,
|
||||
unique_indices=False) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __setitem__(self, key, value) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __iter__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __reversed__(self) -> Any: ...
|
||||
|
||||
# Comparisons
|
||||
@abc.abstractmethod
|
||||
def __lt__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __le__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __eq__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __ne__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __gt__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __ge__(self, other) -> Any: ...
|
||||
|
||||
# Unary arithmetic
|
||||
|
||||
@abc.abstractmethod
|
||||
def __neg__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __pos__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __abs__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __invert__(self) -> Any: ...
|
||||
|
||||
# Binary arithmetic
|
||||
|
||||
@abc.abstractmethod
|
||||
def __add__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __sub__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __mul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __matmul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __truediv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __floordiv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __mod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __divmod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __pow__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __lshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __and__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __xor__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __or__(self, other) -> Any: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __radd__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rsub__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rmul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rmatmul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rtruediv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rfloordiv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rmod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rdivmod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rpow__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rlshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rrshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rand__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rxor__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __ror__(self, other) -> Any: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __bool__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __complex__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __int__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __float__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __round__(self, ndigits=None) -> Any: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __index__(self) -> Any: ...
|
||||
|
||||
# np.ndarray methods:
|
||||
@abc.abstractmethod
|
||||
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def astype(self, dtype) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def choose(self, choices, out=None, mode='raise') -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def clip(self, min=None, max=None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def conj(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def conjugate(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def copy(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def dot(self, b, *, precision=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def flatten(self) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def imag(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def item(self, *args) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=False, *, where=None,) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def nbytes(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def nonzero(self, *, size=None, fill_value=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=False,) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def ravel(self, order='C') -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def real(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def repeat(self, repeats, axis: Optional[int] = None, *,
|
||||
total_repeat_length=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def reshape(self, *args, order='C') -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def round(self, decimals=0, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def searchsorted(self, v, side='left', sorter=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def swapaxes(self, axis1: int, axis2: int) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def take(self, indices, axis: Optional[int] = None, out=None,
|
||||
mode=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def tobytes(self, order='C') -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def tolist(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
|
||||
out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def transpose(self, *args) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def view(self, dtype=None, type=None) -> Any: ...
|
||||
|
||||
# Even though we don't always support the NumPy array protocol, e.g., for
|
||||
# tracer types, for type checking purposes we must declare support so we
|
||||
# implement the NumPy ArrayLike protocol.
|
||||
def __array__(self) -> Any: ...
|
||||
|
||||
def __dlpack__(self) -> Any: ...
|
||||
|
||||
# JAX extensions
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def at(self) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def aval(self) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def weak_type(self) -> bool: ...
|
||||
|
||||
|
||||
ndarray.register(device_array.DeviceArray)
|
||||
for t in device_array.device_array_types:
|
||||
ndarray.register(t)
|
||||
ndarray.register(pxla._SDA_BASE_CLASS)
|
||||
from jax._src.typing import Array as ndarray
|
||||
|
@ -578,7 +578,7 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:
|
||||
lo = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype)
|
||||
hi = np.array(1., dtype)
|
||||
u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type]
|
||||
return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
|
||||
return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u))
|
||||
|
||||
|
||||
def multivariate_normal(key: KeyArray,
|
||||
|
@ -30,6 +30,9 @@ from typing import Any, Sequence, Union
|
||||
from typing_extensions import Protocol
|
||||
import numpy as np
|
||||
|
||||
from jax._src.basearray import Array
|
||||
|
||||
|
||||
class HasDTypeAttribute(Protocol):
|
||||
dtype: DType
|
||||
|
||||
@ -51,11 +54,7 @@ Shape = Sequence[DimSize]
|
||||
|
||||
# Array is a type annotation for standard JAX arrays and tracers produced by
|
||||
# core functions in jax.lax and jax.numpy; it is not meant to include
|
||||
# future non-standard array types like KeyArray and BInt.
|
||||
# For now we set it to Any; in the future this will be more restrictive
|
||||
# (see https://github.com/google/jax/pull/11859)
|
||||
# TODO(jakevdp): make this conform to the JEP 12049 plan.
|
||||
Array = Any
|
||||
# future non-standard array types like KeyArray and BInt. It is imported above.
|
||||
|
||||
# ArrayLike is a Union of all objects that can be implicitly converted to a standard
|
||||
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
|
||||
|
10
jax/core.py
10
jax/core.py
@ -50,6 +50,7 @@ from jax._src import lib
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src import traceback_util
|
||||
from jax._src.typing import DimSize, Shape
|
||||
from jax._src import typing
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
@ -530,9 +531,10 @@ def escaped_tracer_error(tracer, detail=None):
|
||||
msg += f'Detail: {detail}'
|
||||
return UnexpectedTracerError(msg)
|
||||
|
||||
class Tracer:
|
||||
|
||||
class Tracer(typing.Array):
|
||||
__array_priority__ = 1000
|
||||
__slots__ = ['_trace', '__weakref__', '_line_info']
|
||||
__slots__ = ['_trace', '_line_info']
|
||||
|
||||
def __array__(self, *args, **kw):
|
||||
raise TracerArrayConversionError(self)
|
||||
@ -556,6 +558,10 @@ class Tracer:
|
||||
def __len__(self):
|
||||
return self.aval._len(self)
|
||||
|
||||
@property
|
||||
def at(self):
|
||||
return self.aval.at.fget(self)
|
||||
|
||||
@property
|
||||
def aval(self):
|
||||
raise NotImplementedError("must override")
|
||||
|
@ -22,6 +22,7 @@ from jax import core
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
@ -29,7 +30,7 @@ from jax._src.config import config
|
||||
from jax._src.util import prod, safe_zip
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.api import device_put
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.typing import ArrayLike
|
||||
from jax.interpreters import pxla, xla, mlir
|
||||
from jax.experimental.sharding import (
|
||||
Sharding, SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
|
||||
@ -39,7 +40,6 @@ Shape = Tuple[int, ...]
|
||||
Device = xc.Device
|
||||
DeviceArray = xc.Buffer
|
||||
Index = Tuple[slice, ...]
|
||||
ArrayLike = Union[np.ndarray, DeviceArray]
|
||||
|
||||
|
||||
class Shard:
|
||||
@ -101,7 +101,7 @@ def _single_device_array_from_buf(buf, committed):
|
||||
|
||||
|
||||
@pxla.use_cpp_class(xc.Array if xc._version >= 92 else None)
|
||||
class Array:
|
||||
class Array(basearray.Array):
|
||||
"""Experimental unified Array type.
|
||||
|
||||
This Python implementation will eventually be replaced by a C++ implementation.
|
||||
@ -498,7 +498,9 @@ xla.canonicalize_dtype_handlers[Array] = pxla.identity
|
||||
api_util._shaped_abstractify_handlers[Array] = op.attrgetter('aval')
|
||||
ad_util.jaxval_adders[Array] = lax_internal.add
|
||||
ad_util.jaxval_zeros_likers[Array] = lax_internal.zeros_like_array
|
||||
ndarray.register(Array)
|
||||
if xc._version >= 92:
|
||||
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
||||
basearray.Array.register(Array)
|
||||
|
||||
|
||||
def _array_mlir_constant_handler(val, canonicalize_types=True):
|
||||
|
@ -62,6 +62,7 @@ from jax.tree_util import tree_flatten, tree_map
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
@ -689,6 +690,7 @@ if _USE_CPP_SDA:
|
||||
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore
|
||||
else:
|
||||
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore
|
||||
basearray.Array.register(_SDA_BASE_CLASS)
|
||||
|
||||
|
||||
class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
|
||||
|
@ -41,6 +41,7 @@ from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
|
||||
# TODO: update callers to refer to new location.
|
||||
from jax._src.util import extend_name_stack as extend_name_stack # noqa: F401
|
||||
from jax._src.util import wrap_name as wrap_name # noqa: F401
|
||||
from jax._src.typing import Shape
|
||||
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -134,9 +135,9 @@ def parameter(builder, num, shape, name=None, replicated=None):
|
||||
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
|
||||
# checkers don't support recursive types), so we only represent one level of
|
||||
# nesting in this type definition.
|
||||
SpatialSharding = Union[Tuple[int, ...],
|
||||
SpatialSharding = Union[Shape,
|
||||
None,
|
||||
Tuple[Optional[Tuple[int, ...]], ...]]
|
||||
Tuple[Optional[Shape], ...]]
|
||||
|
||||
def sharding_to_proto(sharding: SpatialSharding):
|
||||
"""Converts a SpatialSharding to an OpSharding.
|
||||
|
@ -17,14 +17,18 @@ Typing tests
|
||||
This test is meant to be both a runtime test and a static type annotation test,
|
||||
so it should be checked with pytype/mypy as well as being run with pytest.
|
||||
"""
|
||||
from typing import Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import typing
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.experimental.array import Array as ArrayImpl
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
@ -98,9 +102,6 @@ class TypingTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out9, jnp.float32(0))
|
||||
|
||||
def testArrayInstanceChecks(self):
|
||||
# TODO(jakevdp): enable this test when `typing.Array` instance checks are implemented.
|
||||
self.skipTest("Test is broken for now.")
|
||||
|
||||
def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]:
|
||||
return isinstance(x, typing.Array)
|
||||
|
||||
@ -112,6 +113,21 @@ class TypingTest(jtu.JaxTestCase):
|
||||
self.assertTrue(jax.jit(is_array)(x))
|
||||
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
|
||||
|
||||
def testAnnotations(self):
|
||||
# This test is mainly meant for static type checking: we want to ensure that
|
||||
# Tracer and ArrayImpl are valid as array.Array.
|
||||
with jax_config.jax_array(True):
|
||||
def f(x: Any) -> Optional[typing.Array]:
|
||||
if isinstance(x, core.Tracer):
|
||||
return x
|
||||
elif isinstance(x, ArrayImpl):
|
||||
return x
|
||||
else:
|
||||
return None
|
||||
|
||||
x = jnp.arange(10)
|
||||
y = f(x)
|
||||
self.assertArraysEqual(x, y)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user