Merge pull request #12421 from jakevdp:jax-array

PiperOrigin-RevId: 476898184
This commit is contained in:
jax authors 2022-09-26 08:07:11 -07:00
commit 2df61b1aa1
17 changed files with 288 additions and 303 deletions

View File

@ -18,6 +18,7 @@ repos:
hooks:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py # Use pyi instead
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]
- repo: https://github.com/mwouts/jupytext

View File

@ -15,6 +15,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
{jax-issue}`#7733`) is stable and public. See [the
overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs
for {mod}`jax.stages`.
* Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks
and type annotations for array types in JAX. Notice that this included some subtle
changes to how `isinstance` works for {class}`jax.numpy.ndarray` for jax-internal
objects, as {class}`jax.numpy.ndarray` is now a simple alias of {class}`jax.Array`.
* Breaking changes
* `jax._src` is no longer imported into the from the public `jax` namespace.
This may break users that were using JAX internals.

View File

@ -35,6 +35,8 @@ del _cloud_tpu_init
from jax import config as _config_module
del _config_module
from jax._src.basearray import Array as Array
from jax._src.config import (
config as config,
enable_checks as enable_checks,

View File

@ -612,8 +612,7 @@ def _cpp_jit(
# to know whether `jax.jit(f)(x)` will execute or trace, it's not enough to
# inspect the argument x, we actually do need to execute it and look at the
# outputs that could be tracers (if f is capturing `Tracer` by closure).
execute: Optional[functools.partial] = (
dispatch.xla_callable.most_recent_entry())
execute = dispatch.xla_callable.most_recent_entry()
fastpath_data = None

53
jax/_src/basearray.py Normal file
View File

@ -0,0 +1,53 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note that type annotations for this file are defined in basearray.pyi
import abc
class Array(abc.ABC):
"""Experimental Array base class for JAX
`jax.Array` is meant as the future public interface for instance checks and
type annotation of JAX array objects. JAX Array object types are currently in
flux, and this class only fully supports the new `jax.experimental.Array`, which
will soon replace the old-style {class}`DeviceArray`, {class}`ShardedDeviceArray`,
{class}`GlobalDeviceArray`, etc.
The compatibility is summarized in the following table:
================================ ====================== =========================
object type ``isinstance`` support type annotation support
================================ ====================== =========================
{class}`DeviceArray`
{class}`ShardedDeviceArray`
{class}`GlobalDeviceArray`
{class}`~jax.core.Tracer`
{class}`~jax.experimental.Array`
================================ ====================== =========================
In other words, ``isinstance(x, jax.Array)`` will return True for any of these types,
whereas annotations such as ``x : jax.Array`` will only type-check correctly for
instances of {class}`~jax.core.Tracer` and {class}`jax.experimental.Array`, and not
for the other soon-to-be-deprecated array types.
"""
# Note: no abstract methods are defined in this base class; the associated pyi
# file contains the type signature for static type checking.
__slots__ = ['__weakref__']
# at property must be defined because we overwrite its docstring in lax_numpy.py
@property
def at(self):
raise NotImplementedError("property must be defined in subclasses")

172
jax/_src/basearray.pyi Normal file
View File

@ -0,0 +1,172 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any, List, Optional, Tuple, Union
import numpy as np
class Array(abc.ABC):
dtype: np.dtype
ndim: int
size: int
aval: Any
@property
def shape(self) -> Tuple[int, ...]: ...
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
def __getitem__(self, key, indices_are_sorted=False,
unique_indices=False) -> Array: ...
def __setitem__(self, key, value) -> None: ...
def __len__(self) -> int: ...
def __iter__(self) -> Any: ...
def __reversed__(self) -> Any: ...
def __round__(self, ndigits=None) -> Array: ...
# Comparisons
# these return bool for object, so ignore override errors.
def __lt__(self, other) -> Array: ... # type: ignore[override]
def __le__(self, other) -> Array: ... # type: ignore[override]
def __eq__(self, other) -> Array: ... # type: ignore[override]
def __ne__(self, other) -> Array: ... # type: ignore[override]
def __gt__(self, other) -> Array: ... # type: ignore[override]
def __ge__(self, other) -> Array: ... # type: ignore[override]
# Unary arithmetic
def __neg__(self) -> Array: ...
def __pos__(self) -> Array: ...
def __abs__(self) -> Array: ...
def __invert__(self) -> Array: ...
# Binary arithmetic
def __add__(self, other) -> Array: ...
def __sub__(self, other) -> Array: ...
def __mul__(self, other) -> Array: ...
def __matmul__(self, other) -> Array: ...
def __truediv__(self, other) -> Array: ...
def __floordiv__(self, other) -> Array: ...
def __mod__(self, other) -> Array: ...
def __divmod__(self, other) -> Array: ...
def __pow__(self, other) -> Array: ...
def __lshift__(self, other) -> Array: ...
def __rshift__(self, other) -> Array: ...
def __and__(self, other) -> Array: ...
def __xor__(self, other) -> Array: ...
def __or__(self, other) -> Array: ...
def __radd__(self, other) -> Array: ...
def __rsub__(self, other) -> Array: ...
def __rmul__(self, other) -> Array: ...
def __rmatmul__(self, other) -> Array: ...
def __rtruediv__(self, other) -> Array: ...
def __rfloordiv__(self, other) -> Array: ...
def __rmod__(self, other) -> Array: ...
def __rdivmod__(self, other) -> Array: ...
def __rpow__(self, other) -> Array: ...
def __rlshift__(self, other) -> Array: ...
def __rrshift__(self, other) -> Array: ...
def __rand__(self, other) -> Array: ...
def __rxor__(self, other) -> Array: ...
def __ror__(self, other) -> Array: ...
def __bool__(self) -> bool: ...
def __complex__(self) -> complex: ...
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __index__(self) -> int: ...
# np.ndarray methods:
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None) -> Array: ...
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None) -> Array: ...
def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
def astype(self, dtype) -> Array: ...
def choose(self, choices, out=None, mode='raise') -> Array: ...
def clip(self, min=None, max=None, out=None) -> Array: ...
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
def conj(self) -> Array: ...
def conjugate(self) -> Array: ...
def copy(self) -> Array: ...
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None) -> Array: ...
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None) -> Array: ...
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
def dot(self, b, *, precision=None) -> Array: ...
def flatten(self) -> Array: ...
@property
def imag(self) -> Array: ...
def item(self, *args) -> Any: ...
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None) -> Array: ...
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False, *, where=None,) -> Array: ...
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None) -> Array: ...
@property
def nbytes(self) -> int: ...
def nonzero(self, *, size=None, fill_value=None) -> Array: ...
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Array: ...
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=False,) -> Array: ...
def ravel(self, order='C') -> Array: ...
@property
def real(self) -> Array: ...
def repeat(self, repeats, axis: Optional[int] = None, *,
total_repeat_length=None) -> Array: ...
def reshape(self, *args, order='C') -> Array: ...
def round(self, decimals=0, out=None) -> Array: ...
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Array: ...
def swapaxes(self, axis1: int, axis2: int) -> Array: ...
def take(self, indices, axis: Optional[int] = None, out=None,
mode=None) -> Array: ...
def tobytes(self, order='C') -> bytes: ...
def tolist(self) -> List[Any]: ...
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
out=None) -> Array: ...
def transpose(self, *args) -> Array: ...
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
def view(self, dtype=None, type=None) -> Array: ...
# Even though we don't always support the NumPy array protocol, e.g., for
# tracer types, for type checking purposes we must declare support so we
# implement the NumPy ArrayLike protocol.
def __array__(self) -> np.ndarray: ...
def __dlpack__(self) -> Any: ...
# JAX extensions
@property
def at(self) -> Any: ...
@property
def weak_type(self) -> bool: ...

View File

@ -37,6 +37,7 @@ from jax._src import source_info_util, traceback_util
from jax._src.lax import control_flow as cf
from jax._src.config import config
from jax import lax
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
safe_zip)
@ -62,9 +63,9 @@ def setnewattr(obj, name, val):
## Error value data type and functional assert.
Bool = Union[bool, core.Tracer]
Int = Union[int, core.Tracer]
Payload = Union[np.ndarray, jnp.ndarray, core.Tracer]
Bool = Union[bool, Array]
Int = Union[int, Array]
Payload = Union[np.ndarray, Array]
# For now, the payload needs to be a fixed-size array: 3 int32s, used for the
# OOB message.

View File

@ -30,6 +30,7 @@ from jax._src import dtypes
from jax._src import profiler
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src.typing import Array
### device-persistent data
@ -332,7 +333,9 @@ class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()
Array.register(DeviceArray)
device_array_types: List[type] = [xc.Buffer, _DeviceArray]
for _device_array in device_array_types:
core.literalable_types.add(_device_array)
core.pytype_aval_mappings[_device_array] = abstract_arrays.canonical_concrete_aval
Array.register(_device_array)

View File

@ -3601,12 +3601,12 @@ def take_along_axis(arr, indices, axis: Optional[int],
j += 1
gather_indices = lax.concatenate(gather_indices, dimension=j)
gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(offset_dims),
collapsed_slice_dims=tuple(collapsed_slice_dims),
start_index_map=tuple(start_index_map))
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes),
return lax.gather(arr, gather_indices_arr, dnums, tuple(slice_sizes),
mode="fill" if mode is None else mode)
### Indexing

View File

@ -12,282 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any, Tuple, Optional, Union
__all__ = ['ndarray']
from jax import core
from jax.interpreters import pxla
from jax._src import device_array
import numpy as np
class ArrayMeta(abc.ABCMeta):
"""Metaclass for overriding ndarray isinstance checks."""
def __instancecheck__(self, instance):
# Allow tracer instances with avals that are instances of UnshapedArray.
# We could instead just declare Tracer an instance of the ndarray type, but
# there can be traced values that are not arrays. The main downside here is
# that isinstance(x, ndarray) might return true but
# issubclass(type(x), ndarray) might return false for an array tracer.
try:
return isinstance(instance.aval, core.UnshapedArray)
except AttributeError:
super().__instancecheck__(instance)
class ndarray(metaclass=ArrayMeta):
dtype: np.dtype
ndim: int
shape: Tuple[int, ...]
size: int
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
@abc.abstractmethod
def __getitem__(self, key, indices_are_sorted=False,
unique_indices=False) -> Any: ...
@abc.abstractmethod
def __setitem__(self, key, value) -> Any: ...
@abc.abstractmethod
def __len__(self) -> Any: ...
@abc.abstractmethod
def __iter__(self) -> Any: ...
@abc.abstractmethod
def __reversed__(self) -> Any: ...
# Comparisons
@abc.abstractmethod
def __lt__(self, other) -> Any: ...
@abc.abstractmethod
def __le__(self, other) -> Any: ...
@abc.abstractmethod
def __eq__(self, other) -> Any: ...
@abc.abstractmethod
def __ne__(self, other) -> Any: ...
@abc.abstractmethod
def __gt__(self, other) -> Any: ...
@abc.abstractmethod
def __ge__(self, other) -> Any: ...
# Unary arithmetic
@abc.abstractmethod
def __neg__(self) -> Any: ...
@abc.abstractmethod
def __pos__(self) -> Any: ...
@abc.abstractmethod
def __abs__(self) -> Any: ...
@abc.abstractmethod
def __invert__(self) -> Any: ...
# Binary arithmetic
@abc.abstractmethod
def __add__(self, other) -> Any: ...
@abc.abstractmethod
def __sub__(self, other) -> Any: ...
@abc.abstractmethod
def __mul__(self, other) -> Any: ...
@abc.abstractmethod
def __matmul__(self, other) -> Any: ...
@abc.abstractmethod
def __truediv__(self, other) -> Any: ...
@abc.abstractmethod
def __floordiv__(self, other) -> Any: ...
@abc.abstractmethod
def __mod__(self, other) -> Any: ...
@abc.abstractmethod
def __divmod__(self, other) -> Any: ...
@abc.abstractmethod
def __pow__(self, other) -> Any: ...
@abc.abstractmethod
def __lshift__(self, other) -> Any: ...
@abc.abstractmethod
def __rshift__(self, other) -> Any: ...
@abc.abstractmethod
def __and__(self, other) -> Any: ...
@abc.abstractmethod
def __xor__(self, other) -> Any: ...
@abc.abstractmethod
def __or__(self, other) -> Any: ...
@abc.abstractmethod
def __radd__(self, other) -> Any: ...
@abc.abstractmethod
def __rsub__(self, other) -> Any: ...
@abc.abstractmethod
def __rmul__(self, other) -> Any: ...
@abc.abstractmethod
def __rmatmul__(self, other) -> Any: ...
@abc.abstractmethod
def __rtruediv__(self, other) -> Any: ...
@abc.abstractmethod
def __rfloordiv__(self, other) -> Any: ...
@abc.abstractmethod
def __rmod__(self, other) -> Any: ...
@abc.abstractmethod
def __rdivmod__(self, other) -> Any: ...
@abc.abstractmethod
def __rpow__(self, other) -> Any: ...
@abc.abstractmethod
def __rlshift__(self, other) -> Any: ...
@abc.abstractmethod
def __rrshift__(self, other) -> Any: ...
@abc.abstractmethod
def __rand__(self, other) -> Any: ...
@abc.abstractmethod
def __rxor__(self, other) -> Any: ...
@abc.abstractmethod
def __ror__(self, other) -> Any: ...
@abc.abstractmethod
def __bool__(self) -> Any: ...
@abc.abstractmethod
def __complex__(self) -> Any: ...
@abc.abstractmethod
def __int__(self) -> Any: ...
@abc.abstractmethod
def __float__(self) -> Any: ...
@abc.abstractmethod
def __round__(self, ndigits=None) -> Any: ...
@abc.abstractmethod
def __index__(self) -> Any: ...
# np.ndarray methods:
@abc.abstractmethod
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None) -> Any: ...
@abc.abstractmethod
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None) -> Any: ...
@abc.abstractmethod
def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
@abc.abstractmethod
def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
@abc.abstractmethod
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ...
@abc.abstractmethod
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
@abc.abstractmethod
def astype(self, dtype) -> Any: ...
@abc.abstractmethod
def choose(self, choices, out=None, mode='raise') -> Any: ...
@abc.abstractmethod
def clip(self, min=None, max=None, out=None) -> Any: ...
@abc.abstractmethod
def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ...
@abc.abstractmethod
def conj(self) -> Any: ...
@abc.abstractmethod
def conjugate(self) -> Any: ...
@abc.abstractmethod
def copy(self) -> Any: ...
@abc.abstractmethod
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None) -> Any: ...
@abc.abstractmethod
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None) -> Any: ...
@abc.abstractmethod
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ...
@abc.abstractmethod
def dot(self, b, *, precision=None) -> Any: ...
@abc.abstractmethod
def flatten(self) -> Any: ...
@property
@abc.abstractmethod
def imag(self) -> Any: ...
@abc.abstractmethod
def item(self, *args) -> Any: ...
@abc.abstractmethod
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None) -> Any: ...
@abc.abstractmethod
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False, *, where=None,) -> Any: ...
@abc.abstractmethod
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None) -> Any: ...
@property
@abc.abstractmethod
def nbytes(self) -> Any: ...
@abc.abstractmethod
def nonzero(self, *, size=None, fill_value=None) -> Any: ...
@abc.abstractmethod
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Any: ...
@abc.abstractmethod
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=False,) -> Any: ...
@abc.abstractmethod
def ravel(self, order='C') -> Any: ...
@property
@abc.abstractmethod
def real(self) -> Any: ...
@abc.abstractmethod
def repeat(self, repeats, axis: Optional[int] = None, *,
total_repeat_length=None) -> Any: ...
@abc.abstractmethod
def reshape(self, *args, order='C') -> Any: ...
@abc.abstractmethod
def round(self, decimals=0, out=None) -> Any: ...
@abc.abstractmethod
def searchsorted(self, v, side='left', sorter=None) -> Any: ...
@abc.abstractmethod
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
@abc.abstractmethod
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ...
@abc.abstractmethod
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
@abc.abstractmethod
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Any: ...
@abc.abstractmethod
def swapaxes(self, axis1: int, axis2: int) -> Any: ...
@abc.abstractmethod
def take(self, indices, axis: Optional[int] = None, out=None,
mode=None) -> Any: ...
@abc.abstractmethod
def tobytes(self, order='C') -> Any: ...
@abc.abstractmethod
def tolist(self) -> Any: ...
@abc.abstractmethod
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
out=None) -> Any: ...
@abc.abstractmethod
def transpose(self, *args) -> Any: ...
@abc.abstractmethod
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
@abc.abstractmethod
def view(self, dtype=None, type=None) -> Any: ...
# Even though we don't always support the NumPy array protocol, e.g., for
# tracer types, for type checking purposes we must declare support so we
# implement the NumPy ArrayLike protocol.
def __array__(self) -> Any: ...
def __dlpack__(self) -> Any: ...
# JAX extensions
@property
@abc.abstractmethod
def at(self) -> Any: ...
@property
@abc.abstractmethod
def aval(self) -> Any: ...
@property
@abc.abstractmethod
def weak_type(self) -> bool: ...
ndarray.register(device_array.DeviceArray)
for t in device_array.device_array_types:
ndarray.register(t)
ndarray.register(pxla._SDA_BASE_CLASS)
from jax._src.typing import Array as ndarray

View File

@ -578,7 +578,7 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:
lo = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype)
hi = np.array(1., dtype)
u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type]
return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u))
def multivariate_normal(key: KeyArray,

View File

@ -30,6 +30,9 @@ from typing import Any, Sequence, Union
from typing_extensions import Protocol
import numpy as np
from jax._src.basearray import Array
class HasDTypeAttribute(Protocol):
dtype: DType
@ -51,11 +54,7 @@ Shape = Sequence[DimSize]
# Array is a type annotation for standard JAX arrays and tracers produced by
# core functions in jax.lax and jax.numpy; it is not meant to include
# future non-standard array types like KeyArray and BInt.
# For now we set it to Any; in the future this will be more restrictive
# (see https://github.com/google/jax/pull/11859)
# TODO(jakevdp): make this conform to the JEP 12049 plan.
Array = Any
# future non-standard array types like KeyArray and BInt. It is imported above.
# ArrayLike is a Union of all objects that can be implicitly converted to a standard
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).

View File

@ -50,6 +50,7 @@ from jax._src import lib
from jax._src.lib import jax_jit
from jax._src import traceback_util
from jax._src.typing import DimSize, Shape
from jax._src import typing
traceback_util.register_exclusion(__file__)
zip, unsafe_zip = safe_zip, zip
@ -530,9 +531,10 @@ def escaped_tracer_error(tracer, detail=None):
msg += f'Detail: {detail}'
return UnexpectedTracerError(msg)
class Tracer:
class Tracer(typing.Array):
__array_priority__ = 1000
__slots__ = ['_trace', '__weakref__', '_line_info']
__slots__ = ['_trace', '_line_info']
def __array__(self, *args, **kw):
raise TracerArrayConversionError(self)
@ -556,6 +558,10 @@ class Tracer:
def __len__(self):
return self.aval._len(self)
@property
def at(self):
return self.aval.at.fget(self)
@property
def aval(self):
raise NotImplementedError("must override")

View File

@ -22,6 +22,7 @@ from jax import core
from jax._src import abstract_arrays
from jax._src import ad_util
from jax._src import api_util
from jax._src import basearray
from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
@ -29,7 +30,7 @@ from jax._src.config import config
from jax._src.util import prod, safe_zip
from jax._src.lib import xla_client as xc
from jax._src.api import device_put
from jax._src.numpy.ndarray import ndarray
from jax._src.typing import ArrayLike
from jax.interpreters import pxla, xla, mlir
from jax.experimental.sharding import (
Sharding, SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
@ -39,7 +40,6 @@ Shape = Tuple[int, ...]
Device = xc.Device
DeviceArray = xc.Buffer
Index = Tuple[slice, ...]
ArrayLike = Union[np.ndarray, DeviceArray]
class Shard:
@ -101,7 +101,7 @@ def _single_device_array_from_buf(buf, committed):
@pxla.use_cpp_class(xc.Array if xc._version >= 92 else None)
class Array:
class Array(basearray.Array):
"""Experimental unified Array type.
This Python implementation will eventually be replaced by a C++ implementation.
@ -498,7 +498,9 @@ xla.canonicalize_dtype_handlers[Array] = pxla.identity
api_util._shaped_abstractify_handlers[Array] = op.attrgetter('aval')
ad_util.jaxval_adders[Array] = lax_internal.add
ad_util.jaxval_zeros_likers[Array] = lax_internal.zeros_like_array
ndarray.register(Array)
if xc._version >= 92:
# TODO(jakevdp) replace this with true inheritance at the C++ level.
basearray.Array.register(Array)
def _array_mlir_constant_handler(val, canonicalize_types=True):

View File

@ -62,6 +62,7 @@ from jax.tree_util import tree_flatten, tree_map
from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import basearray
from jax._src import device_array
from jax._src import dtypes
from jax._src import source_info_util
@ -689,6 +690,7 @@ if _USE_CPP_SDA:
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore
else:
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore
basearray.Array.register(_SDA_BASE_CLASS)
class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore

View File

@ -41,6 +41,7 @@ from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
# TODO: update callers to refer to new location.
from jax._src.util import extend_name_stack as extend_name_stack # noqa: F401
from jax._src.util import wrap_name as wrap_name # noqa: F401
from jax._src.typing import Shape
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
@ -134,9 +135,9 @@ def parameter(builder, num, shape, name=None, replicated=None):
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
SpatialSharding = Union[Shape,
None,
Tuple[Optional[Tuple[int, ...]], ...]]
Tuple[Optional[Shape], ...]]
def sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.

View File

@ -17,14 +17,18 @@ Typing tests
This test is meant to be both a runtime test and a static type annotation test,
so it should be checked with pytype/mypy as well as being run with pytest.
"""
from typing import Union
from typing import Any, Optional, Union
import jax
from jax import core
from jax._src import config as jax_config
from jax._src import test_util as jtu
from jax._src import typing
from jax import lax
import jax.numpy as jnp
from jax.experimental.array import Array as ArrayImpl
from absl.testing import absltest
import numpy as np
@ -98,9 +102,6 @@ class TypingTest(jtu.JaxTestCase):
self.assertArraysEqual(out9, jnp.float32(0))
def testArrayInstanceChecks(self):
# TODO(jakevdp): enable this test when `typing.Array` instance checks are implemented.
self.skipTest("Test is broken for now.")
def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]:
return isinstance(x, typing.Array)
@ -112,6 +113,21 @@ class TypingTest(jtu.JaxTestCase):
self.assertTrue(jax.jit(is_array)(x))
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
def testAnnotations(self):
# This test is mainly meant for static type checking: we want to ensure that
# Tracer and ArrayImpl are valid as array.Array.
with jax_config.jax_array(True):
def f(x: Any) -> Optional[typing.Array]:
if isinstance(x, core.Tracer):
return x
elif isinstance(x, ArrayImpl):
return x
else:
return None
x = jnp.arange(10)
y = f(x)
self.assertArraysEqual(x, y)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())