rocm_jax/jax/_src/numpy/lax_numpy.py
Peter Hawkins a4412e2715 Remove internal ndarray type name. Use Array throughout.
jax.numpy.ndarray remains an exported alias for jax.Array.

PiperOrigin-RevId: 513046188
2023-02-28 14:51:08 -08:00

5650 lines
217 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
NumPy operations are implemented in Python in terms of the primitive operations
in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
implemented in terms of :mod:`jax.lax` operations, we do not need to define
transformation rules such as gradient or batching rules. Instead,
transformations for NumPy primitives can be derived from the transformation
rules for the underlying :code:`lax` primitives.
"""
import builtins
import collections
from functools import partial
import math
import operator
import types
from typing import (
overload, Any, Callable, Dict, FrozenSet, List, Literal,
Optional, Sequence, Tuple, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings
import numpy as np
import opt_einsum
import jax
from jax import jit
from jax import errors
from jax import lax
from jax.interpreters import pxla
from jax.tree_util import tree_leaves, tree_flatten, tree_map
from jax._src import api_util
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.core import ShapedArray, DShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
from jax._src.lax import lax as lax_internal
from jax._src.lib import pmap_lib
from jax._src.lib import xla_client
from jax._src.numpy.reductions import ( # noqa: F401
_ensure_optional_axes, _reduction_dims,
alltrue, amin, amax, any, all, average, count_nonzero, cumsum, cumprod, cumproduct,
max, mean, min, nancumsum, nancumprod, nanmax, nanmean, nanmin, nanprod, nanstd,
nansum, nanvar, prod, product, ptp, sometrue, std, sum, var,
)
from jax._src.numpy.ufuncs import ( # noqa: F401
abs, absolute, add, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh,
bitwise_and, bitwise_not, bitwise_or, bitwise_xor, cbrt, ceil, conj, conjugate,
copysign, cos, cosh, deg2rad, degrees, divide, divmod, equal, exp, exp2, expm1,
fabs, float_power, floor, floor_divide, fmod, frexp, greater, greater_equal,
heaviside, hypot, imag, invert, isfinite, isinf, isnan, isneginf, isposinf,
ldexp, left_shift, less, less_equal, log, log10, log1p, log2, logaddexp, logaddexp2,
logical_and, logical_not, logical_or, logical_xor, maximum, minimum, mod, modf,
multiply, negative, nextafter, not_equal, positive, power, rad2deg, radians, real,
reciprocal, remainder, right_shift, rint, sign, signbit, sin, sinc, sinh, sqrt,
square, subtract, tan, tanh, true_divide)
from jax._src.numpy.util import ( # noqa: F401
_arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike,
_complex_elem_type, _promote_args, _promote_args_inexact, _promote_dtypes,
_promote_dtypes_numeric, _promote_dtypes_inexact, _promote_shapes,
_register_stackable, _stackable, _where, _wraps)
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl
# Alias for None in indexing expressions; mirrors `numpy.newaxis`, which is
# also None.
newaxis = None
# Generic type variable for annotations in this module.
T = TypeVar('T')
def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
  """Canonicalize `shape`, also accepting a bare (non-sequence) dimension.

  Behaves like :func:`core.canonicalize_shape`, except that an argument which
  is not a tuple/list and has rank zero (by its ``ndim`` attribute or by
  ``ndim(shape)``) is first wrapped into a 1-tuple.
  """
  is_sequence = isinstance(shape, (tuple, list))
  if not is_sequence and (getattr(shape, 'ndim', None) == 0 or ndim(shape) == 0):
    shape = (shape,)
  return core.canonicalize_shape(shape, context)  # type: ignore
# Common docstring additions:
# Appended (via `lax_description`) to the docs of functions that accept a
# `precision` argument, e.g. `convolve` and `correlate` below.
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, a :class:`~jax.lax.Precision` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple
of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument.
"""
# Some objects below rewrite their __module__ attribute to this name.
_PUBLIC_MODULE_NAME = "jax.numpy"
# We replace some builtin names to follow Numpy's API, so we capture here.
# (abs, all, any, max, min, sum and divmod are rebound at module level by the
# imports from `reductions` and `ufuncs` above; these underscore aliases keep
# access to the Python builtins.)
_abs = builtins.abs
_all = builtins.all
_any = builtins.any
_max = builtins.max
_min = builtins.min
_sum = builtins.sum
_divmod = builtins.divmod
# NumPy constants
# NOTE(review): np.NINF, np.PZERO and np.NZERO are removed in NumPy 2.0;
# verify the supported NumPy range before relying on these aliases.
pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
NINF = np.NINF
PZERO = np.PZERO
NZERO = np.NZERO
nan = np.nan
# NumPy utility functions
# Re-exported unchanged: printing options are shared with NumPy.
get_printoptions = np.get_printoptions
printoptions = np.printoptions
set_printoptions = np.set_printoptions
@_wraps(np.iscomplexobj)
def iscomplexobj(x: Any) -> bool:
  # Use the object's own dtype when it has one; otherwise convert it to an
  # array first to discover the dtype.
  try:
    scalar_kind = x.dtype.type
  except AttributeError:
    scalar_kind = asarray(x).dtype.type
  is_complex = issubdtype(scalar_kind, complexfloating)
  return is_complex
# Shape, rank and size queries are delegated directly to NumPy. The underscore
# aliases `_shape`/`_ndim` remain reachable where the public names are
# shadowed by locals or parameters.
shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size
def _dtype(x: Any) -> DType:
return dtypes.dtype(x, canonicalize=True)
# JAX currently has no reason to distinguish scalars from arrays in its object
# system, and JAX scalars should follow the same type-promotion rules as JAX
# arrays. So instead of introducing a separate JAX scalar object, the JAX
# scalar types return JAX arrays when they are called.
class _ScalarMeta(type):
  """Metaclass for JAX scalar types (e.g. ``jnp.float32``).

  Instances carry a ``dtype`` attribute. Calling one converts its argument to
  a JAX array of that dtype. Hashing, equality and ``isinstance`` checks defer
  to the corresponding NumPy scalar type ``self.dtype.type``.
  """

  def __hash__(self) -> int:
    np_type = self.dtype.type
    return hash(np_type)

  def __eq__(self, other: Any) -> bool:
    if self is other:
      return True
    return self.dtype.type == other

  def __ne__(self, other: Any) -> bool:
    return not self == other

  def __call__(self, x: Any) -> Array:
    return asarray(x, dtype=self.dtype)

  def __instancecheck__(self, instance: Any) -> bool:
    np_type = self.dtype.type
    return isinstance(instance, np_type)
def _abstractify_scalar_meta(x):
  """Reject a JAX scalar *type* (e.g. ``jnp.float32`` itself) as an array value."""
  message = f"JAX scalar type {x} cannot be interpreted as a JAX array."
  raise TypeError(message)
# Register the handler for every _ScalarMeta instance; presumably consulted
# when JAX abstractifies arguments (e.g. during tracing) — confirm in api_util.
api_util._shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
  """Create the JAX scalar type that mirrors NumPy's `np_scalar_type`.

  The new type has the same ``__name__``, a ``dtype`` class attribute, and
  reports its module as ``jax.numpy``.
  """
  type_name = np_scalar_type.__name__
  namespace = {"dtype": np.dtype(np_scalar_type)}
  scalar_type = _ScalarMeta(type_name, (object,), namespace)
  scalar_type.__module__ = _PUBLIC_MODULE_NAME
  return scalar_type
# JAX scalar types: calling one returns a JAX array of that dtype.
bool_ = _make_scalar_type(np.bool_)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
# float8 types are only exposed when the bundled xla_client is recent enough.
if xla_client._version >= 117:
  float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
  float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
float64 = double = _make_scalar_type(np.float64)
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)
# Default-width aliases follow whatever `dtypes` reports as the default
# (32- or 64-bit).
int_ = int32 if dtypes.int_ == np.int32 else int64
uint = uint32 if dtypes.uint == np.uint32 else uint64
float_: Any = float32 if dtypes.float_ == np.float32 else float64
complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
# Abstract scalar hierarchy re-exported from NumPy (e.g. for issubdtype).
generic = np.generic
number = np.number
inexact = np.inexact
complexfloating = np.complexfloating
floating = np.floating
integer = np.integer
signedinteger = np.signedinteger
unsignedinteger = np.unsignedinteger
flexible = np.flexible
character = np.character
object_ = np.object_
# dtype utilities: info/promotion come from JAX's `dtypes` module.
iinfo = dtypes.iinfo
finfo = dtypes.finfo
dtype = np.dtype
can_cast = dtypes.can_cast
issubsctype = dtypes.issubsctype
promote_types = dtypes.promote_types
# NOTE(review): np.ComplexWarning is no longer in the top-level NumPy
# namespace as of NumPy 2.0 (see np.exceptions); confirm supported versions.
ComplexWarning = np.ComplexWarning
array_str = np.array_str
array_repr = np.array_repr
save = np.save
savez = np.savez
@_wraps(np.dtype)
def _jnp_dtype(obj: Optional[DTypeLike], *, align: bool = False,
               copy: bool = False) -> DType:
  """Similar to np.dtype, but respects JAX dtype defaults."""
  if obj is None:
    # No dtype requested: fall back to JAX's default floating-point dtype.
    return np.dtype(dtypes.float_, align=align, copy=copy)
  if isinstance(obj, type) and obj in dtypes.python_scalar_dtypes:
    # Python builtin scalar types map through _DEFAULT_TYPEMAP to the JAX
    # scalar type of the default width.
    np_scalar = np.dtype(obj, align=align, copy=copy).type
    return np.dtype(_DEFAULT_TYPEMAP[np_scalar], align=align, copy=copy)
  return np.dtype(obj, align=align, copy=copy)
### utility functions
# Maps NumPy's default scalar types to the JAX scalar types of the configured
# default width (see int_/float_/complex_ above); used by `_jnp_dtype`.
# NOTE(review): np.float_ and np.complex_ are removed in NumPy 2.0 — verify.
_DEFAULT_TYPEMAP: Dict[type, _ScalarMeta] = {
  np.bool_: bool_,
  np.int_: int_,
  np.float_: float_,
  np.complex_: complex_
}
# Shorthand for lax_internal._const; presumably builds a constant with the
# dtype of its first argument (e.g. `_lax_const(x, 0)`) — confirm in lax.
_lax_const = lax_internal._const
def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
  """
  Convert integer-typed val to specified integer dtype, clipping to dtype
  range rather than wrapping.

  Args:
    val: value to be converted
    dtype: dtype of output

  Returns:
    equivalent of val in new dtype

  Examples
  --------
  Normal integer type conversion will wrap:

  >>> val = jnp.uint32(0xFFFFFFFF)
  >>> val.astype('int32')
  Array(-1, dtype=int32)

  This function clips to the values representable in the new type:

  >>> _convert_and_clip_integer(val, 'int32')
  Array(2147483647, dtype=int32)
  """
  arr = val if isinstance(val, Array) else asarray(val)
  target = dtypes.canonicalize_dtype(dtype)
  both_integer = issubdtype(target, integer) and issubdtype(arr.dtype, integer)
  if not both_integer:
    raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")
  source = dtypes.canonicalize_dtype(arr.dtype)
  # TODO(jakevdp): `source` can differ from `arr.dtype`; this is a weird corner
  # case that still needs handling. It happens in X32 mode and can come either
  # from a jax value created in another context or from a Python integer
  # converted to int64.
  target_info, source_info = iinfo(target), iinfo(source)
  # Clip to the intersection of both ranges before casting, so the cast itself
  # never wraps.
  lower = _lax_const(arr, _max(target_info.min, source_info.min))
  upper = _lax_const(arr, _min(target_info.max, source_info.max))
  return clip(arr, lower, upper).astype(target)
@_wraps(np.load, update_doc=False)
def load(*args: Any, **kwargs: Any) -> Array:
  # This wrapper mainly exists to recover bfloat16 data, which NumPy does not
  # know and serializes as 2-byte void ('V2'). Only files written by np.save()
  # are handled; non-ndarray results (e.g. np.savez archives) are returned
  # unchanged.
  out = np.load(*args, **kwargs)
  if not isinstance(out, np.ndarray):
    return out
  if out.dtype == 'V2':
    out = out.view(bfloat16)
  try:
    return asarray(out)
  except TypeError:  # Unsupported dtype: hand back the NumPy array.
    return out
### implementations of numpy functions in terms of lax
@_wraps(np.fmin, module='numpy')
@jit
def fmin(x1: ArrayLike, x2: ArrayLike) -> Array:
  # NaN-ignoring minimum: pick x1 when it is strictly smaller or x2 is NaN;
  # otherwise pick x2 (which is how a NaN in x1 gets skipped).
  pick_first = less(x1, x2) | isnan(x2)
  return where(pick_first, x1, x2)
@_wraps(np.fmax, module='numpy')
@jit
def fmax(x1: ArrayLike, x2: ArrayLike) -> Array:
  # NaN-ignoring maximum: pick x1 when it is strictly larger or x2 is NaN;
  # otherwise pick x2 (which is how a NaN in x1 gets skipped).
  pick_first = greater(x1, x2) | isnan(x2)
  return where(pick_first, x1, x2)
@_wraps(np.issubdtype)
def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
  # Delegate to JAX's `dtypes.issubdtype` rather than np.issubdtype.
  is_sub = dtypes.issubdtype(arg1, arg2)
  return is_sub
@_wraps(np.isscalar)
def isscalar(element: Any) -> bool:
  # Objects implementing __jax_array__ are judged by the value they convert to.
  candidate = (element.__jax_array__() if hasattr(element, '__jax_array__')
               else element)
  return dtypes.is_python_scalar(candidate) or np.isscalar(candidate)
# Re-exported unchanged from NumPy.
iterable = np.iterable
@_wraps(np.result_type)
def result_type(*args: ArrayLike) -> DType:
  # JAX's promotion rules (dtypes.result_type), not NumPy's.
  promoted = dtypes.result_type(*args)
  return promoted
@_wraps(np.trapz)
@partial(jit, static_argnames=('axis',))
def trapz(y: ArrayLike, x: Optional[ArrayLike] = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array:
  # Trapezoidal rule along `axis`: sum of dx * (y[i+1] + y[i]) / 2.
  if x is not None:
    _check_arraylike('trapz', y, x)
    y_arr, x_arr = _promote_dtypes_inexact(y, x)
    # Spacings come from `x`; for N-d `x`, move the integration axis last so it
    # lines up with `y_arr` below.
    if x_arr.ndim == 1:
      dx = diff(x_arr)
    else:
      dx = moveaxis(diff(x_arr, axis=axis), axis, -1)
  else:
    _check_arraylike('trapz', y)
    y_arr, = _promote_dtypes_inexact(y)
  y_arr = moveaxis(y_arr, axis, -1)
  neighbor_sums = y_arr[..., 1:] + y_arr[..., :-1]
  return 0.5 * (dx * neighbor_sums).sum(-1)
@_wraps(np.trunc, module='numpy')
@jit
def trunc(x: ArrayLike) -> Array:
  _check_arraylike('trunc', x)
  # Round toward zero: ceil for negative values, floor otherwise.
  is_negative = lax.lt(x, _lax_const(x, 0))
  return where(is_negative, ceil(x), floor(x))
@partial(jit, static_argnums=(2, 3, 4))
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array:
  """Shared 1-D implementation of ``convolve`` and ``correlate``.

  `op` selects the operation ('convolve' or 'correlate'); `mode` is one of
  'full', 'same' or 'valid'. Computed with `lax.conv_general_dilated`.
  """
  if ndim(x) != 1 or ndim(y) != 1:
    raise ValueError(f"{op}() only support 1-dimensional inputs.")
  x, y = _promote_dtypes_inexact(x, y)
  if len(x) == 0 or len(y) == 0:
    raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")

  # Keep the longer operand as the input and the shorter as the kernel.
  shorter_first = len(x) < len(y)
  out_order = slice(None)
  if op == 'correlate':
    y = conj(y)
    if shorter_first:
      x, y = y, x
      # Swapping correlation operands reverses the output.
      out_order = slice(None, None, -1)
  elif op == 'convolve':
    if shorter_first:
      x, y = y, x
    # lax convolution is a correlation, so flip the kernel for a convolution.
    y = flip(y)

  k = y.shape[0]
  if mode == 'full':
    padding = [(k - 1, k - 1)]
  elif mode == 'same':
    padding = [(k // 2, k - k // 2 - 1)]
  elif mode == 'valid':
    padding = [(0, 0)]
  else:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")

  lhs = x[None, None, :]
  rhs = y[None, None, :]
  result = lax.conv_general_dilated(lhs, rhs, (1,), padding, precision=precision)
  return result[0, 0, out_order]
@_wraps(np.convolve, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
             precision: PrecisionLike = None) -> Array:
  _check_arraylike("convolve", a, v)
  a_arr, v_arr = asarray(a), asarray(v)
  return _conv(a_arr, v_arr, mode, 'convolve', precision)
@_wraps(np.correlate, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
              precision: PrecisionLike = None) -> Array:
  _check_arraylike("correlate", a, v)
  a_arr, v_arr = asarray(a), asarray(v)
  return _conv(a_arr, v_arr, mode, 'correlate', precision)
@_wraps(np.histogram_bin_edges)
def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
                        range: Union[None, Array, Sequence[ArrayLike]] = None,
                        weights: Optional[ArrayLike] = None) -> Array:
  del weights  # unused, because string bins is not supported.
  if isinstance(bins, str):
    raise NotImplementedError("string values for `bins` not implemented.")
  _check_arraylike("histogram_bin_edges", a, bins)
  flat = ravel(a)
  dtype = dtypes.to_inexact_dtype(flat.dtype)
  # 1-D `bins` are explicit edges: return them cast to the inexact dtype.
  if _ndim(bins) == 1:
    return asarray(bins, dtype=dtype)
  # Otherwise `bins` must be a concrete count of equal-width bins.
  bins_int = core.concrete_or_error(operator.index, bins,
                                    "bins argument of histogram_bin_edges")
  bounds = asarray([flat.min(), flat.max()] if range is None else range,
                   dtype=dtype)
  if shape(bounds) != (2,):
    raise ValueError(f"`range` must be either None or a sequence of scalars, got {bounds}")
  # Widen a zero-width range to +/- 0.5, as NumPy does.
  degenerate = ptp(bounds) == 0
  lo = where(degenerate, bounds[0] - 0.5, bounds[0])
  hi = where(degenerate, bounds[1] + 0.5, bounds[1])
  return linspace(lo, hi, bins_int + 1, dtype=dtype)
@_wraps(np.histogram)
def histogram(a: ArrayLike, bins: ArrayLike = 10,
              range: Optional[Sequence[ArrayLike]] = None,
              weights: Optional[ArrayLike] = None,
              density: Optional[bool] = None) -> Tuple[Array, Array]:
  if weights is not None:
    _check_arraylike("histogram", a, bins, weights)
    if shape(a) != shape(weights):
      raise ValueError("weights should have the same shape as a.")
    a, weights = map(ravel, _promote_dtypes_inexact(a, weights))
  else:
    _check_arraylike("histogram", a, bins)
    a = ravel(*_promote_dtypes_inexact(a))
    weights = ones_like(a)
  edges = histogram_bin_edges(a, bins, range, weights)
  # searchsorted(side='right') yields 1-based bin indices (0 means below the
  # first edge and is dropped by the [1:] slice). Values equal to the last edge
  # are moved into the last bin so the final interval is closed.
  idx = searchsorted(edges, a, side='right')
  idx = where(a == edges[-1], len(edges) - 1, idx)
  counts = bincount(idx, weights, length=len(edges))[1:]
  if density:
    counts = counts / diff(edges) / counts.sum()
  return counts, edges
@_wraps(np.histogram2d)
def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
                range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]]=None,
                weights: Optional[ArrayLike] = None,
                density: Optional[bool] = None) -> Tuple[Array, Array, Array]:
  _check_arraylike("histogram2d", x, y)
  try:
    N = len(bins)  # type: ignore[arg-type]
  except TypeError:
    N = 1
  if N not in (1, 2):
    # A single array of edges: share it between both dimensions.
    shared_edges = asarray(bins)
    bins = [shared_edges, shared_edges]
  # Stack the coordinates into an (n_points, 2) sample for histogramdd.
  sample = transpose(asarray([x, y]))
  hist, edges = histogramdd(sample, bins, range, weights, density)
  return hist, edges[0], edges[1]
@_wraps(np.histogramdd)
def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
                range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = None,
                weights: Optional[ArrayLike] = None,
                density: Optional[bool] = None) -> Tuple[Array, List[Array]]:
  # `sample` must be 2-D, shape (N, D): N points in D dimensions (enforced by
  # the `N, D = shape(sample)` unpacking below). Returns the D-dimensional
  # histogram and the list of per-dimension bin edges.
  if weights is None:
    _check_arraylike("histogramdd", sample)
    sample, = _promote_dtypes_inexact(sample)
  else:
    _check_arraylike("histogramdd", sample, weights)
    if shape(weights) != shape(sample)[:1]:
      raise ValueError("should have one weight for each sample.")
    sample, weights = _promote_dtypes_inexact(sample, weights)
  N, D = shape(sample)
  if range is not None and (
      len(range) != D or _any(r is not None and shape(r)[0] != 2 for r in range)):  # type: ignore[arg-type]
    raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence "
                     f"of {D} pairs or Nones; got {range=}")
  try:
    num_bins = len(bins)  # type: ignore[arg-type]
  except TypeError:
    # when bin_size is integer, the same bin is used for each dimension
    bins_per_dimension: List[ArrayLike] = D * [bins]  # type: ignore[assignment]
  else:
    if num_bins != D:
      raise ValueError("should be a bin for each dimension.")
    bins_per_dimension = list(bins)  # type: ignore[arg-type]

  bin_idx_by_dim: List[Array] = []
  bin_edges_by_dim: List[Array] = []
  # `range` is a parameter here, so the builtin must be reached explicitly.
  for i in builtins.range(D):
    range_i = None if range is None else range[i]
    bin_edges = histogram_bin_edges(sample[:, i], bins_per_dimension[i], range_i, weights)
    bin_idx = searchsorted(bin_edges, sample[:, i], side='right')
    # Values equal to the last edge belong to the last (closed) bin.
    bin_idx = where(sample[:, i] == bin_edges[-1], bin_idx - 1, bin_idx)
    bin_idx_by_dim.append(bin_idx)
    bin_edges_by_dim.append(bin_edges)

  # searchsorted yields indices 0..len(edges), i.e. len(edges) + 1 slots per
  # dimension; slot 0 and the final slot hold out-of-range values.
  nbins = tuple(len(bin_edges) + 1 for bin_edges in bin_edges_by_dim)
  dedges = [diff(bin_edges) for bin_edges in bin_edges_by_dim]

  xy = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip')
  hist = bincount(xy, weights, length=math.prod(nbins))
  hist = reshape(hist, nbins)
  # Drop the out-of-range slots on every dimension. The slice is named
  # `interior` rather than `core` so it does not shadow the imported `core`
  # module for the whole function body.
  interior = D * (slice(1, -1),)
  hist = hist[interior]

  if density:
    hist = hist.astype(sample.dtype)
    hist /= hist.sum()
    for norm in ix_(*dedges):
      hist /= norm

  return hist, bin_edges_by_dim
# Note appended (via `lax_description`) to the docs of functions that return
# views in NumPy, e.g. transpose, rot90, flip, reshape, ravel, squeeze.
_ARRAY_VIEW_DOC = """
The JAX version of this function may in some cases return a copy rather than a
view of the input.
"""
@_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array:
  if not _stackable(a):
    _check_arraylike("transpose", a)
  rank = ndim(a)
  # Default permutation reverses all axes.
  order = range(rank)[::-1] if axes is None else axes
  perm = [_canonicalize_axis(i, rank) for i in order]
  return lax.transpose(a, perm)
@_wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array:
  _check_arraylike("rot90", m)
  rank = ndim(m)
  first, second = axes
  first = _canonicalize_axis(first, rank)
  second = _canonicalize_axis(second, rank)
  if first == second:
    raise ValueError("Axes must be different")  # same as numpy error
  turns = k % 4
  if turns == 0:
    return asarray(m)
  if turns == 2:
    return flip(flip(m, first), second)
  # Odd number of quarter turns: swap the two axes and flip one of them.
  perm = list(range(rank))
  perm[first], perm[second] = perm[second], perm[first]
  if turns == 1:
    return transpose(flip(m, second), perm)
  return flip(transpose(m, perm), second)
@_wraps(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
  _check_arraylike("flip", m)
  arr = asarray(m)
  # Normalize `axis` to a hashable form for the jitted helper.
  return _flip(arr, _ensure_optional_axes(axis))
@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
  """Reverse `m` along `axis` (an int or tuple), or along every axis if None."""
  rank = ndim(m)
  if axis is None:
    dims = list(range(rank))
  else:
    dims = [_canonicalize_axis(ax, rank) for ax in _ensure_index_tuple(axis)]
  return lax.rev(m, dims)
@_wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC)
def fliplr(m: ArrayLike) -> Array:
  _check_arraylike("fliplr", m)
  arr = asarray(m)
  # Left/right corresponds to axis 1.
  return _flip(arr, 1)
@_wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC)
def flipud(m: ArrayLike) -> Array:
  _check_arraylike("flipud", m)
  arr = asarray(m)
  # Up/down corresponds to axis 0.
  return _flip(arr, 0)
@_wraps(np.iscomplex)
@jit
def iscomplex(x: ArrayLike) -> Array:
  # Elementwise: True where the imaginary part is nonzero.
  imag_part = imag(x)
  zero = _lax_const(imag_part, 0)
  return lax.ne(imag_part, zero)
@_wraps(np.isreal)
@jit
def isreal(x: ArrayLike) -> Array:
  # Elementwise: True where the imaginary part is exactly zero.
  imag_part = imag(x)
  zero = _lax_const(imag_part, 0)
  return lax.eq(imag_part, zero)
@_wraps(np.angle)
@partial(jit, static_argnames=['deg'])
def angle(z: ArrayLike, deg: bool = False) -> Array:
  re, im = real(z), imag(z)
  dtype = _dtype(re)
  # Non-inexact inputs, and 0-d real floating inputs, are computed in the
  # canonical default float dtype.
  use_default_float = (not issubdtype(dtype, inexact) or
                       (issubdtype(_dtype(z), floating) and ndim(z) == 0))
  if use_default_float:
    dtype = dtypes.canonicalize_dtype(float_)
  re = lax.convert_element_type(re, dtype)
  im = lax.convert_element_type(im, dtype)
  radians_ = lax.atan2(im, re)
  return degrees(radians_) if deg else radians_
@_wraps(np.diff)
@partial(jit, static_argnames=('n', 'axis'))
def diff(a: ArrayLike, n: int = 1, axis: int = -1,
         prepend: Optional[ArrayLike] = None,
         append: Optional[ArrayLike] = None) -> Array:
  # n-th discrete difference along `axis`, optionally after concatenating
  # `prepend`/`append` onto that axis. Boolean inputs use `not_equal` rather
  # than subtraction.
  _check_arraylike("diff", a)
  arr = asarray(a)
  n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
  axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff")
  if n == 0:
    return arr
  if n < 0:
    raise ValueError(f"order must be non-negative but got {n}")
  if arr.ndim == 0:
    raise ValueError(f"diff requires input that is at least one dimensional; got {a}")

  nd = arr.ndim
  axis = _canonicalize_axis(axis, nd)

  def _as_edge(value: ArrayLike) -> Array:
    # Like NumPy, broadcast any 0-d value to `arr`'s shape with length 1
    # along `axis`. `diff` is jitted, so `value` arrives here as a tracer.
    # Checking ndim (instead of `isscalar`) also covers NumPy scalars and 0-d
    # arrays, which `isscalar` may not recognize once traced.
    _check_arraylike("diff", value)
    if ndim(value) == 0:
      edge_shape = list(arr.shape)
      edge_shape[axis] = 1
      value = broadcast_to(value, tuple(edge_shape))
    return asarray(value)

  combined: List[Array] = []
  if prepend is not None:
    combined.append(_as_edge(prepend))
  combined.append(arr)
  if append is not None:
    combined.append(_as_edge(append))
  if len(combined) > 1:
    arr = concatenate(combined, axis)

  slice1 = [slice(None)] * nd
  slice2 = [slice(None)] * nd
  slice1[axis] = slice(1, None)
  slice2[axis] = slice(None, -1)
  slice1_tuple = tuple(slice1)
  slice2_tuple = tuple(slice2)

  op = not_equal if arr.dtype == np.bool_ else subtract
  for _ in range(n):
    arr = op(arr[slice1_tuple], arr[slice2_tuple])

  return arr
# Extra documentation attached to jnp.ediff1d via `lax_description`.
_EDIFF1D_DOC = """\
Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not
issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary``
loses precision.
"""
@_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC)
@jit
def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = None,
            to_begin: Optional[ArrayLike] = None) -> Array:
  _check_arraylike("ediff1d", ary)
  flat = ravel(ary)
  steps = lax.sub(flat[1:], flat[:-1])
  # to_begin/to_end are flattened and cast to the input dtype before being
  # attached to either end of the differences.
  pieces = [steps]
  if to_begin is not None:
    _check_arraylike("ediff1d", to_begin)
    pieces.insert(0, ravel(asarray(to_begin, dtype=flat.dtype)))
  if to_end is not None:
    _check_arraylike("ediff1d", to_end)
    pieces.append(ravel(asarray(to_end, dtype=flat.dtype)))
  return steps if len(pieces) == 1 else concatenate(pieces)
@_wraps(np.gradient, skip_params=['edge_order'])
@partial(jit, static_argnames=('axis', 'edge_order'))
def gradient(f: ArrayLike, *varargs: ArrayLike,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             edge_order: Optional[int] = None) -> Union[Array, List[Array]]:
  # Numerical gradient of `f` with constant (scalar) spacing per axis.
  # Returns one array for a single axis, otherwise a list of arrays.
  if edge_order is not None:
    raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")
  a, *spacing = _promote_args_inexact("gradient", f, *varargs)

  def gradient_along_axis(a, h, axis):
    # One-sided differences at both ends and central differences in the
    # interior, divided by the scalar spacing `h`.
    sliced = partial(lax.slice_in_dim, a, axis=axis)
    a_grad = concatenate((
      (sliced(1, 2) - sliced(0, 1)),  # first element
      (sliced(2, None) - sliced(None, -2)) * 0.5,  # interior
      (sliced(-1, None) - sliced(-2, -1)),  # last element
    ), axis)
    return a_grad / h

  if axis is None:
    axis_tuple = tuple(range(a.ndim))
  else:
    if isinstance(axis, int):
      axis = (axis,)
    elif not isinstance(axis, tuple) and not isinstance(axis, list):
      raise ValueError("Give `axis` either as int or iterable")
    elif len(axis) == 0:
      return []
    axis_tuple = tuple(_canonicalize_axis(i, a.ndim) for i in axis)

  if _min([s for i, s in enumerate(a.shape) if i in axis_tuple]) < 2:
    raise ValueError("Shape of array too small to calculate "
                     "a numerical gradient, "
                     "at least 2 elements are required.")

  if len(spacing) == 0:
    dx: Sequence[ArrayLike] = [1.0] * len(axis_tuple)
  elif len(spacing) == 1:
    # A single spacing applies to every requested axis.
    dx = list(spacing) * len(axis_tuple)
  elif len(spacing) == len(axis_tuple):
    dx = list(spacing)
  else:
    # Must actually raise: otherwise `dx` is unbound below and the caller sees
    # an UnboundLocalError instead of this message.
    raise TypeError(f"Invalid number of spacing arguments {len(spacing)} for {axis=}")

  if ndim(dx[0]) != 0:
    raise NotImplementedError("Non-constant spacing not implemented")

  a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis_tuple, dx)]
  return a_grad[0] if len(axis_tuple) == 1 else a_grad
@_wraps(np.isrealobj)
def isrealobj(x: Any) -> bool:
  # Anything whose dtype is not complex counts as real.
  has_complex_dtype = iscomplexobj(x)
  return not has_complex_dtype
@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a: ArrayLike, newshape: Union[DimSize, Shape], order: str = "C") -> Array:
  if not _stackable(a):
    _check_arraylike("reshape", a)
  # Objects with their own reshape method (e.g. arrays) handle it themselves;
  # everything else is converted and reshaped via the private helper.
  try:
    return a.reshape(newshape, order=order)  # type: ignore[call-overload,union-attr]
  except AttributeError:
    arr = asarray(a)
    return _reshape(arr, newshape, order=order)
def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape:
  """Fixes a -1 value in newshape, if present.

  A non-iterable `newshape` (e.g. a bare int) is treated as a one-element
  shape. Other errors, like having more than one -1, are caught downstream,
  in reshape_shape_rule.
  """
  try:
    iter(newshape)  # type: ignore[arg-type]
  except TypeError:
    # Only "not iterable" should be handled here; a bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit and unrelated errors.
    newshape = [newshape]  # type: ignore[list-item]
  newshape = core.canonicalize_shape(newshape)  # type: ignore[arg-type]
  # Replace -1 with the inferred size. Presumably divide_shape_sizes counts the
  # -1 in the product of `newshape`, hence the negation — confirm in core.
  return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
               if core.symbolic_equal_dim(d, -1) else d
               for d in newshape)
def _item(a: Array) -> Any:
"""Copy an element of an array to a standard Python scalar and return it."""
if dtypes.issubdtype(a.dtype, np.complexfloating):
return complex(a)
elif dtypes.issubdtype(a.dtype, np.floating):
return float(a)
elif dtypes.issubdtype(a.dtype, np.integer):
return int(a)
elif dtypes.issubdtype(a.dtype, np.bool_):
return bool(a)
else:
raise TypeError(a.dtype)
def _reshape(a: Array, *args: Any, order: str = "C") -> Array:
  """Returns an array containing the same data with a new shape.

  Refer to :func:`jax.numpy.reshape` for full documentation.
  """
  target = args[0] if len(args) == 1 else args
  newshape = _compute_newshape(a, target)
  if order == "C":
    return lax.reshape(a, newshape, None)
  if order == "F":
    # Column-major: read `a` with its axes reversed (the `dimensions`
    # argument), reshape to the reversed target shape, then transpose back.
    reversed_dims = list(range(ndim(a)))[::-1]
    return lax.reshape(a, newshape[::-1], reversed_dims).T
  if order == "A":
    raise NotImplementedError("np.reshape order=A is not implemented.")
  raise ValueError(f"Unexpected value for 'order' argument: {order}.")
def _transpose(a: Array, *args: Any) -> Array:
  """Returns a view of the array with axes transposed.

  Accepts the axes either as separate arguments or as a single sequence
  (or None). Refer to :func:`jax.numpy.transpose` for full documentation.
  """
  if len(args) == 0:
    axis = None
  elif len(args) > 1:
    axis = _ensure_index_tuple(args)
  else:
    only = args[0]
    axis = None if only is None else _ensure_index_tuple(only)
  return transpose(a, axis)
@_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a: ArrayLike, order: str = "C") -> Array:
  if not _stackable(a):
    _check_arraylike("ravel", a)
  if order == "K":
    raise NotImplementedError("Ravel not implemented for order='K'.")
  # Flatten to 1-D (length = total element count) in the requested order.
  flat_shape = (size(a),)
  return reshape(a, flat_shape, order)
@_wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...],
                      mode: str = 'raise', order: str = 'C') -> Array:
  assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
  dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
  _check_arraylike("ravel_multi_index", *multi_index)
  indices = [asarray(i) for i in multi_index]
  for index in indices:
    if mode == 'raise':
      # Bounds checking below needs concrete values, so 'raise' cannot be traced.
      core.concrete_or_error(array, index,
        "The error occurred because ravel_multi_index was jit-compiled"
        " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
    if not issubdtype(_dtype(index), integer):
      raise TypeError("only int indices permitted")

  if mode == "raise":
    out_of_bounds = _any(any((i < 0) | (i >= d)) for i, d in zip(indices, dims))
    if out_of_bounds:
      raise ValueError("invalid entry in coordinates array")
  elif mode == "clip":
    indices = [clip(i, 0, d - 1) for i, d in zip(indices, dims)]
  elif mode == "wrap":
    indices = [i % d for i, d in zip(indices, dims)]
  else:
    raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'")

  # Row-major (C) puts the largest stride first; column-major (F) the smallest.
  if order == "C":
    strides = np.cumprod((1,) + dims[1:][::-1])[::-1]
  elif order == "F":
    strides = np.cumprod((1,) + dims[:-1])
  else:
    raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")

  result_dtype = (indices[0].dtype if indices
                  else dtypes.canonicalize_dtype(int_))
  result = array(0, dtype=result_dtype)
  for idx, stride in zip(indices, strides):
    result = result + idx * int(stride)
  return result
# Extra documentation attached to jnp.unravel_index via `lax_description`.
_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped into the valid range.
"""
@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]:
  _check_arraylike("unravel_index", indices)
  quotient = asarray(indices)
  # Note: we do not convert shape to an array, because it may be passed as a
  # tuple of weakly-typed values, and asarray() would strip these weak types.
  try:
    shape = list(shape)
  except TypeError:
    shape = [shape]
  if _any(ndim(s) != 0 for s in shape):
    raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
  # Peel off coordinates from the last (fastest-varying) dimension inward.
  out_indices = [0] * len(shape)
  for i in reversed(range(len(shape))):
    quotient, out_indices[i] = divmod(quotient, shape[i])
  # The leftover quotient is 0 for in-range non-negative input and -1 for
  # in-range negative input; anything else is out of bounds and gets clipped.
  too_big = quotient > 0
  too_small = quotient < -1
  return tuple(where(too_big, s - 1, where(too_small, 0, i))
               for s, i in safe_zip(shape, out_indices))
@_wraps(np.resize)
@partial(jit, static_argnames=('new_shape',))
def resize(a: ArrayLike, new_shape: Shape) -> Array:
  _check_arraylike("resize", a)
  new_shape = _ensure_index_tuple(new_shape)
  if _any(extent < 0 for extent in new_shape):
    raise ValueError("all elements of `new_shape` must be non-negative")
  flat = ravel(a)
  total = math.prod(new_shape)
  if flat.size == 0 or total == 0:
    return zeros_like(flat, shape=new_shape)
  # Repeat the flattened input enough times to cover `total`, then truncate.
  filled = tile(flat, ceil_of_ratio(total, flat.size))[:total]
  return reshape(filled, new_shape)
@_wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC)
def squeeze(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
  _check_arraylike("squeeze", a)
  arr = asarray(a)
  # Normalize `axis` to a hashable tuple (or None) for the jitted helper.
  axes = None if axis is None else _ensure_index_tuple(axis)
  return _squeeze(arr, axes)
@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: Tuple[int]) -> Array:
  """Drop the given size-1 axes; with `axis=None`, drop every size-1 axis."""
  if axis is not None:
    return lax.squeeze(a, axis)
  a_shape = shape(a)
  if not core.is_constant_shape(a_shape):
    # We do not even know the rank of the output if the input shape is not known
    raise ValueError("jnp.squeeze with axis=None is not supported with shape polymorphism")
  unit_axes = tuple(i for i, d in enumerate(a_shape) if d == 1)
  return lax.squeeze(a, unit_axes)
@_wraps(np.expand_dims)
def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array:
  if not _stackable(a):
    _check_arraylike("expand_dims", a)
  axes = _ensure_index_tuple(axis)
  # Objects that provide their own expand_dims method handle it themselves.
  if hasattr(a, "expand_dims"):
    return a.expand_dims(axes)  # type: ignore
  return lax.expand_dims(a, axes)
@_wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
  # Start from the identity permutation and exchange the two requested
  # entries; NumPy indexing on the permutation also handles negative axes.
  _check_arraylike("swapaxes", a)
  order = np.arange(ndim(a))
  order[[axis1, axis2]] = order[[axis2, axis1]]
  return lax.transpose(a, list(order))
@_wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC)
def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]],
             destination: Union[int, Sequence[int]]) -> Array:
  # Normalize both axis arguments to hashable tuples so the jitted helper can
  # treat them as static.
  _check_arraylike("moveaxis", a)
  arr = asarray(a)
  src = _ensure_index_tuple(source)
  dst = _ensure_index_tuple(destination)
  return _moveaxis(arr, src, dst)
@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a: Array, source: Tuple[int, ...], destination: Tuple[int, ...]) -> Array:
  """Jitted implementation of :func:`moveaxis` on static axis tuples."""
  rank = ndim(a)
  src = tuple(_canonicalize_axis(ax, rank) for ax in source)
  dst = tuple(_canonicalize_axis(ax, rank) for ax in destination)
  if len(src) != len(dst):
    raise ValueError("Inconsistent number of elements: {} vs {}"
                     .format(len(src), len(dst)))
  # Keep the axes that are not moved in order, then insert each moved axis at
  # its destination. Inserting in increasing destination order keeps earlier
  # inserts from shifting later target positions.
  perm = [ax for ax in range(rank) if ax not in src]
  for target, moved in sorted(zip(dst, src)):
    perm.insert(target, moved)
  return lax.transpose(a, perm)
@_wraps(np.isclose)
@partial(jit, static_argnames=('equal_nan',))
def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08,
            equal_nan: bool = False) -> Array:
  # Elementwise closeness test. For inexact dtypes this computes
  # |a - b| <= atol + rtol * |b| and then fixes up infinities and NaNs.
  # Other dtypes are compared with exact equality (rtol/atol are unused).
  # NOTE(review): NumPy applies rtol/atol to integer inputs as well
  # (e.g. np.isclose(1, 2, atol=1) is True); confirm this divergence is intended.
  a, b = _promote_args("isclose", a, b)
  dtype = _dtype(a)
  if issubdtype(dtype, inexact):
    if issubdtype(dtype, complexfloating):
      # lax.abs of a complex value is real, so the tolerances are cast to the
      # element type presumably returned by _complex_elem_type (real part dtype).
      dtype = _complex_elem_type(dtype)
    rtol = lax.convert_element_type(rtol, dtype)
    atol = lax.convert_element_type(atol, dtype)
    out = lax.le(
      lax.abs(lax.sub(a, b)),
      lax.add(atol, lax.mul(rtol, lax.abs(b))))
    # This corrects the comparisons for infinite and nan values
    a_inf = isinf(a)
    b_inf = isinf(b)
    any_inf = logical_or(a_inf, b_inf)
    both_inf = logical_and(a_inf, b_inf)
    # An infinity is never within tolerance of anything: set False wherever
    # either a or b is infinite.
    out = logical_and(out, logical_not(any_inf))
    # ...except when a and b are the same infinity: set those elements to True.
    same_value = lax.eq(a, b)
    same_inf = logical_and(both_inf, same_value)
    out = logical_or(out, same_inf)

    # Set False wherever either a or b is NaN.
    a_nan = isnan(a)
    b_nan = isnan(b)
    any_nan = logical_or(a_nan, b_nan)
    out = logical_and(out, logical_not(any_nan))
    if equal_nan:
      # With equal_nan=True, set True where both a and b are NaN.
      both_nan = logical_and(a_nan, b_nan)
      out = logical_or(out, both_nan)
    return out
  else:
    return lax.eq(a, b)
@_wraps(np.interp)
@jit
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
           left: Optional[ArrayLike] = None,
           right: Optional[ArrayLike] = None,
           period: Optional[ArrayLike] = None) -> Array:
  # One-dimensional piecewise-linear interpolation of the points (xp, fp),
  # evaluated at x.
  _check_arraylike("interp", x, xp, fp)
  if shape(xp) != shape(fp) or ndim(xp) != 1:
    raise ValueError("xp and fp must be one-dimensional arrays of equal size")
  x_arr, xp_arr = _promote_dtypes_inexact(x, xp)
  fp_arr, = _promote_dtypes_inexact(fp)
  del x, xp, fp

  if dtypes.issubdtype(x_arr.dtype, np.complexfloating):
    raise ValueError("jnp.interp: complex x values not supported.")

  if period is not None:
    if ndim(period) != 0:
      raise ValueError(f"period must be a scalar; got {period}")
    # Periodic mode: reduce x and xp modulo the period, sort xp (carrying fp
    # along), and add one wrapped-around point at each end so that every x
    # has a bracketing interval.
    period = abs(period)
    x_arr = x_arr % period
    xp_arr = xp_arr % period
    xp_arr, fp_arr = lax.sort_key_val(xp_arr, fp_arr)
    xp_arr = concatenate([xp_arr[-1:] - period, xp_arr, xp_arr[:1] + period])
    fp_arr = concatenate([fp_arr[-1:], fp_arr, fp_arr[:1]])

  # i indexes the right end of the bracketing interval [xp[i-1], xp[i]].
  # Clipping keeps both i-1 and i valid for x outside the range of xp.
  # Outside periodic mode xp is not sorted here; searchsorted assumes it is
  # already increasing (not checked).
  i = clip(searchsorted(xp_arr, x_arr, side='right'), 1, len(xp_arr) - 1)
  df = fp_arr[i] - fp_arr[i - 1]
  dx = xp_arr[i] - xp_arr[i - 1]
  delta = x_arr - xp_arr[i - 1]

  epsilon = np.spacing(np.finfo(xp_arr.dtype).eps)
  dx0 = lax.abs(dx) <= epsilon  # Prevent NaN gradients when `dx` is small.
  # For (near) zero-width intervals, use the left value instead of dividing by dx.
  f = where(dx0, fp_arr[i - 1], fp_arr[i - 1] + (delta / where(dx0, 1, dx)) * df)

  # Below xp[0] / above xp[-1], return left / right (defaulting to the first
  # / last value of fp). This is not applied in periodic mode.
  left_arr: ArrayLike = fp_arr[0] if left is None else left
  right_arr: ArrayLike = fp_arr[-1] if right is None else right
  if period is None:
    f = where(x_arr < xp_arr[0], left_arr, f)
    f = where(x_arr > xp_arr[-1], right_arr, f)

  return f
# Typing overloads for `where`. The one-argument form returns a tuple of index
# arrays (as nonzero does); the three-argument form returns a single array.
@overload
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
          size: Optional[int] = None,
          fill_value: Union[None, Array, Tuple[ArrayLike]] = None
          ) -> Tuple[Array, ...]: ...
@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
          size: Optional[int] = None,
          fill_value: Union[None, Array, Tuple[ArrayLike]] = None
          ) -> Array: ...
# Fallback overload for when the type checker cannot tell which form is used.
@overload
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
          y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
          fill_value: Union[None, Array, Tuple[ArrayLike]] = None
          ) -> Union[Array, Tuple[Array, ...]]: ...
@_wraps(np.where,
  lax_description=_dedent("""
    At present, JAX does not support JIT-compilation of the single-argument form
    of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
    three-argument form does not have a data-dependent shape and can be JIT-compiled
    successfully. Alternatively, you can use the optional ``size`` keyword to
    statically specify the expected size of the output.\n\n
    Special care is needed when the ``x`` or ``y`` input to
    :py:func:`jax.numpy.where` could have a value of NaN.
    Specifically, when a gradient is taken
    with :py:func:`jax.grad` (reverse-mode differentiation), a NaN in either
    ``x`` or ``y`` will propagate into the gradient, regardless of the value
    of ``condition``. More information on this behavior and workarounds
    is available in the JAX FAQ:
    https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where"""),
  extra_params=_dedent("""
    size : int, optional
        Only referenced when ``x`` and ``y`` are ``None``. If specified, the indices of the first
        ``size`` elements of the result will be returned. If there are fewer elements than ``size``
        indicates, the return value will be padded with ``fill_value``.
    fill_value : array_like, optional
        When ``size`` is specified and there are fewer than the indicated number of elements, the
        remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
          y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
          fill_value: Union[None, Array, Tuple[ArrayLike]] = None
          ) -> Union[Array, Tuple[Array, ...]]:
  # With neither x nor y given this is nonzero(condition). Otherwise it is
  # the elementwise select, for which size/fill_value are meaningless.
  one_arg_form = x is None and y is None
  if not one_arg_form:
    _check_arraylike("where", condition, x, y)
    if size is not None or fill_value is not None:
      raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
    return _where(condition, x, y)
  _check_arraylike("where", condition)
  return nonzero(condition, size=size, fill_value=fill_value)
@_wraps(np.select)
def select(condlist, choicelist, default=0):
  # Return, elementwise, the choice paired with the first true condition,
  # or `default` where no condition holds.
  n_conds, n_choices = len(condlist), len(choicelist)
  if n_conds != n_choices:
    msg = "condlist must have length equal to choicelist ({} vs {})"
    raise ValueError(msg.format(n_conds, n_choices))
  if n_conds == 0:
    raise ValueError("condlist must be non-empty")
  output, *promoted_choices = _promote_dtypes(default, *choicelist)
  # Apply the pairs from last to first so that earlier conditions overwrite
  # later ones, giving the first matching condition priority.
  for cond, choice in reversed(list(zip(condlist, promoted_choices))):
    output = where(cond, choice, output)
  return output
@_wraps(np.bincount, lax_description="""\
Jax adds the optional `length` parameter which specifies the output length, and
defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
with non-static operands. Values larger than the specified length will be discarded.
If `length` is specified, `minlength` will be ignored.
Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` clips negative values to zero.
""")
def bincount(x: ArrayLike, weights: Optional[ArrayLike] = None,
             minlength: int = 0, *, length: Optional[int] = None) -> Array:
  _check_arraylike("bincount", x)
  x_dtype = _dtype(x)
  if not issubdtype(x_dtype, integer):
    raise TypeError(f"x argument to bincount must have an integer type; got {x_dtype}")
  if ndim(x) != 1:
    raise ValueError("only 1-dimensional input supported.")
  minlength = core.concrete_or_error(operator.index, minlength,
      "The error occurred because of argument 'minlength' of jnp.bincount.")
  if length is not None:
    length = core.concrete_or_error(operator.index, length,
        "The error occurred because of argument 'length' of jnp.bincount.")
  else:
    # Without an explicit length, the output size depends on the data, so x
    # must be concrete.
    x_arr = core.concrete_or_error(asarray, x,
        "The error occurred because of argument 'x' of jnp.bincount. "
        "To avoid this error, pass a static `length` argument.")
    data_length = int(x_arr.max()) + 1 if x_arr.size else 0
    length = _max(minlength, data_length)
  if weights is None:
    weights = np.array(1, dtype=int_)
  elif shape(x) != shape(weights):
    raise ValueError("shape of weights must match shape of x.")
  # Negative values are clipped into bin 0 before the scatter-add.
  counts = zeros(length, _dtype(weights))
  return counts.at[clip(x, 0)].add(weights)
# Typing overloads for broadcast_shapes. Purely integer shapes give an integer
# shape back; shapes that may contain core.Tracer entries may produce them too.
@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
@overload
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
                     ) -> Tuple[Union[int, core.Tracer], ...]: ...
@_wraps(getattr(np, "broadcast_shapes", None))
def broadcast_shapes(*shapes):
  if not shapes:
    return ()
  # A scalar entry is treated as a 1-D shape, as NumPy does.
  as_tuples = [tuple(s) if np.ndim(s) != 0 else (s,) for s in shapes]
  return lax.broadcast_shapes(*as_tuples)
@_wraps(np.broadcast_arrays, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_arrays(*args: ArrayLike) -> List[Array]:
  # Public wrapper; the work is done by the shared _broadcast_arrays helper.
  return _broadcast_arrays(*args)
@_wraps(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
  # Public wrapper; the work is done by the shared _broadcast_to helper.
  return _broadcast_to(array, shape)
def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
           axis: int = 0) -> List[Array]:
  """Shared implementation of the jnp split family.

  ``indices_or_sections`` is either a non-scalar sequence/array of split
  points along ``axis``, or an integer number of sections. Only
  ``op == "array_split"`` accepts a section count that does not evenly divide
  the axis length. ``op`` is also used in error messages.
  """
  _check_arraylike(op, ary)
  ary = asarray(ary)
  axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
  size = ary.shape[axis]
  explicit_points = (
      isinstance(indices_or_sections, (tuple, list)) or
      (isinstance(indices_or_sections, (np.ndarray, Array)) and
       indices_or_sections.ndim > 0))
  if explicit_points:
    # User-provided split points, bracketed by the start and end of the axis.
    points = np.array(
        [core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1")
         for i_s in indices_or_sections], np.int64)  # type: ignore
    split_indices = np.concatenate([[np.int64(0)], points, [np.int64(size)]])
  else:
    num_sections = core.concrete_or_error(np.int64, indices_or_sections,
                                          f"in jax.numpy.{op} argument 1")
    part_size, r = _divmod(size, num_sections)  # type: ignore[misc]
    if r == 0:
      split_indices = [np.int64(i) * part_size  # type: ignore
                       for i in range(num_sections + 1)]  # type: ignore
    elif op == "array_split":
      # The first r sections get part_size + 1 elements, the rest part_size.
      split_indices = (
          [np.int64(i) * (part_size + 1) for i in range(r + 1)] +  # type: ignore
          [np.int64(i) * part_size + ((r + 1) * (part_size + 1) - 1)
           for i in range(num_sections - r)])
    else:
      raise ValueError("array split does not result in an equal division")
  starts, ends = [0] * ndim(ary), shape(ary)
  bounds = zip(split_indices[:-1], split_indices[1:])
  return [lax.slice(ary, subvals(starts, [(axis, lo)]), subvals(ends, [(axis, hi)]))
          for lo, hi in bounds]
@_wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
  # Equal-division split: _split raises ValueError if an integer section
  # count does not evenly divide the axis.
  return _split("split", ary, indices_or_sections, axis=axis)
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
  """Build a NumPy-style split function (vsplit/hsplit/dsplit) for a fixed axis.

  Args:
    op: name of the NumPy function being wrapped (also used in error messages).
    axis: the axis the returned function splits along.

  Returns:
    A function ``f(ary, indices_or_sections)`` that splits ``ary`` along
    ``axis``. As in NumPy's hsplit, a 1-D input is split along axis 0 when
    ``axis == 1``.
  """
  @_wraps(getattr(np, op), update_doc=False)
  def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]:
    _check_arraylike(op, ary)
    a = asarray(ary)
    # For a 1-D array, hsplit becomes vsplit. Decide this per call in a local
    # variable. Rebinding the enclosing `axis` (via `nonlocal`) would leak the
    # change into every later call, so hsplit on 2-D inputs would then split
    # along axis 0.
    split_axis = 0 if (axis == 1 and a.ndim == 1) else axis
    return _split(op, a, indices_or_sections, axis=split_axis)
  return f
# NumPy's fixed-axis split functions: split along axis 0 (vsplit), 1 (hsplit)
# and 2 (dsplit). For 1-D input, hsplit splits along axis 0 (see _split_on_axis).
vsplit = _split_on_axis("vsplit", axis=0)
hsplit = _split_on_axis("hsplit", axis=1)
dsplit = _split_on_axis("dsplit", axis=2)
@_wraps(np.array_split)
def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
  # Like split, but an integer section count need not divide the axis evenly;
  # the leading sections get one extra element each.
  return _split("array_split", ary, indices_or_sections, axis=axis)
@_wraps(np.clip, skip_params=['out'])
@jit
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = None,
         a_max: Optional[ArrayLike] = None, out: None = None) -> Array:
  # Clamp `a` to [a_min, a_max]; either bound may be omitted, but not both.
  _check_arraylike("clip", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
  if a_min is None and a_max is None:
    raise ValueError("At most one of a_min and a_max may be None")
  clamped = a
  if a_min is not None:
    clamped = maximum(a_min, clamped)  # raise values below the lower bound
  if a_max is not None:
    clamped = minimum(a_max, clamped)  # lower values above the upper bound
  return asarray(clamped)
@_wraps(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
  _check_arraylike("round", a)
  decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
  dtype = _dtype(a)
  if issubdtype(dtype, integer):
    if decimals < 0:
      raise NotImplementedError(
        "integer np.round not implemented for decimals < 0")
    return asarray(a)  # no-op on integer types

  is_half = dtype == np.float16

  def _round_real(x: ArrayLike) -> Array:
    # Round half to even, optionally at `decimals` digits after the point.
    if decimals == 0:
      return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
    # TODO(phawkins): the strategy of rescaling the value isn't necessarily a
    # good one since we may be left with an incorrectly rounded value at the
    # end due to precision problems. As a workaround for float16, convert to
    # float32.
    if is_half:
      x = lax.convert_element_type(x, np.float32)
    scale = _lax_const(x, 10 ** decimals)
    scaled = lax.round(lax.mul(x, scale), lax.RoundingMethod.TO_NEAREST_EVEN)
    result = lax.div(scaled, scale)
    return lax.convert_element_type(result, dtype) if is_half else result

  if not issubdtype(dtype, complexfloating):
    return _round_real(a)
  # Complex values: round the real and imaginary parts independently.
  return lax.complex(_round_real(lax.real(a)), _round_real(lax.imag(a)))
# NumPy-compatible aliases for `round`.
around = round
round_ = round
@_wraps(np.fix, skip_params=['out'])
@jit
def fix(x: ArrayLike, out: None = None) -> Array:
  _check_arraylike("fix", x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
  # Round toward zero: floor for non-negative values, ceil for negative ones.
  nonnegative = lax.ge(x, _lax_const(x, 0))
  return where(nonnegative, floor(x), ceil(x))
@_wraps(np.nan_to_num)
@jit
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
               posinf: Optional[ArrayLike] = None,
               neginf: Optional[ArrayLike] = None) -> Array:
  del copy  # unused; a fresh array is always returned
  _check_arraylike("nan_to_num", x)
  dtype = _dtype(x)
  if not issubdtype(dtype, inexact):
    return asarray(x)
  if issubdtype(dtype, complexfloating):
    # Replace non-finite values in the real and imaginary parts independently.
    replace = partial(nan_to_num, nan=nan, posinf=posinf, neginf=neginf)
    return lax.complex(replace(lax.real(x)), replace(lax.imag(x)))
  info = finfo(dtypes.canonicalize_dtype(dtype))
  # By default, +/-inf become the largest/smallest finite value of the dtype.
  replacements = (
      (isnan, nan),
      (isposinf, info.max if posinf is None else posinf),
      (isneginf, info.min if neginf is None else neginf),
  )
  out = x
  for is_special, value in replacements:
    out = where(is_special(out), asarray(value, dtype=dtype), out)
  return out
@_wraps(np.allclose)
@partial(jit, static_argnames=('equal_nan',))
def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
             atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array:
  # Reduce the elementwise isclose result with `all`; the result is an Array
  # rather than a Python bool.
  _check_arraylike("allclose", a, b)
  return all(isclose(a, b, rtol, atol, equal_nan))
_NONZERO_DOC = """\
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.nonzero`` to be used within some of JAX's
transformations.
"""
_NONZERO_EXTRA_PARAMS = """
size : int, optional
If specified, the indices of the first ``size`` True elements will be returned. If there are
fewer unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero.
"""
@_wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a: ArrayLike, *, size: Optional[int] = None,
            fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None
    ) -> Tuple[Array, ...]:
  # Returns one index array per dimension of atleast_1d(a). With `size`, the
  # output length is fixed: it is truncated to the first `size` hits, or
  # padded (with `fill_value`, else zeros) when there are fewer.
  _check_arraylike("nonzero", a)
  arr = atleast_1d(a)
  del a
  mask = arr if arr.dtype == bool else (arr != 0)
  num_nonzero = mask.sum()
  if size is None:
    size = num_nonzero
  size = core.concrete_or_error(operator.index, size,
    "The size argument of jnp.nonzero must be statically specified "
    "to use jnp.nonzero within JAX transformations.")
  if size == 0:
    return tuple(zeros(size, int) for _ in arr.shape)
  if arr.size == 0:
    # Empty input has no nonzero elements: every output slot is padding, so
    # fall through to the fill_value handling below instead of returning
    # zeros (which would silently ignore fill_value).
    out = tuple(zeros(size, int) for _ in arr.shape)
  else:
    # bincount(cumsum(mask)) counts, for each k, how many flat positions have
    # exactly k nonzeros at or before them; its cumsum gives the flat index of
    # the (k+1)-th nonzero element.
    flat_indices = cumsum(bincount(cumsum(mask), length=size))
    strides = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(int_)
    out = tuple((flat_indices // stride) % dim_size
                for stride, dim_size in zip(strides, arr.shape))
  if fill_value is not None:
    fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,)
    if _any(_shape(val) != () for val in fill_value_tup):
      raise ValueError(f"fill_value must be a scalar or a tuple of length {arr.ndim}; got {fill_value}")
    # Positions past the number of actual nonzero elements are padding.
    fill_mask = arange(size) >= num_nonzero
    out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
  return out
@_wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def flatnonzero(a: ArrayLike, *, size: Optional[int] = None,
                fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None) -> Array:
  # Indices into the flattened array: nonzero of the 1-D ravel has a single
  # index array, which is returned directly.
  return nonzero(ravel(a), size=size, fill_value=fill_value)[0]
@_wraps(np.unwrap)
@partial(jit, static_argnames=('axis',))
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None,
           axis: int = -1, period: ArrayLike = 2 * pi) -> Array:
  _check_arraylike("unwrap", p)
  p = asarray(p)
  if issubdtype(p.dtype, np.complexfloating):
    raise ValueError("jnp.unwrap does not support complex inputs.")
  if p.shape[axis] == 0:
    return _promote_dtypes_inexact(p)[0]
  half_period = period / 2
  if discont is None:
    discont = half_period
  steps = diff(p, axis=axis)
  # Map each step into [-period/2, period/2). A step of exactly +period/2
  # keeps its positive sign, matching NumPy.
  wrapped = mod(steps + half_period, period) - half_period
  wrapped = where((wrapped == -half_period) & (steps > 0), half_period, wrapped)
  # Steps smaller than `discont` are left alone; larger ones are corrected.
  correction = where(abs(steps) < discont, 0, wrapped - steps)
  head = lax.slice_in_dim(p, 0, 1, axis=axis)
  tail = lax.slice_in_dim(p, 1, None, axis=axis) + cumsum(correction, axis=axis)
  return concatenate((head, tail), axis=axis)
### Padding

# A pad specification as users may write it: a scalar, a single
# (before, after) pair, or one pair per axis (see _broadcast_to_pairs).
PadValueLike = Union[T, Sequence[T], Sequence[Sequence[T]]]
# Normalized form: one (before, after) pair per axis.
PadValue = Tuple[Tuple[T, T], ...]

# TODO(jakevdp): make this a protocol
# Reduction used by the statistic pad modes; called as f(x, axis=..., keepdims=True).
PadStatFunc = Callable[..., Array]
def _broadcast_to_pairs(nvals: PadValueLike, nd: int, name: str) -> PadValue:
  """Normalize a NumPy-style pad specification to ``nd`` (before, after) pairs.

  Accepted shapes of ``nvals`` are ``(nd, 2)``, ``(1, 2)``, ``(2,)``, ``(1,)``
  and ``()``. Every leaf must be concrete (checked with
  ``core.concrete_or_error``); ``name`` names the argument in error messages.

  Raises:
    TypeError: if the entries are ragged or not scalar dimensions.
    ValueError: if ``nvals`` has any other shape.
  """
  try:
    nvals = np.asarray(tree_map(
      lambda x: core.concrete_or_error(None, x, context=f"{name} argument of jnp.pad"),
      nvals))
  except ValueError as e:
    # In numpy 1.24
    if "array has an inhomogeneous shape" in str(e):
      raise TypeError(f'`{name}` entries must be the same shape: {nvals}') from e
    raise

  def as_scalar_dim(v):
    # Accept scalars and special (symbolic) dimension sizes; reject anything else.
    if core.is_special_dim_size(v) or not np.shape(v):
      return v
    else:
      raise TypeError(f'`{name}` entries must be the same shape: {nvals}')

  if nvals.shape == (nd, 2):
    # ((before_1, after_1), ..., (before_N, after_N))
    return tuple((as_scalar_dim(nval[0]), as_scalar_dim(nval[1])) for nval in nvals)
  elif nvals.shape == (1, 2):
    # ((before, after),)
    v1_2 = as_scalar_dim(nvals[0, 0]), as_scalar_dim(nvals[0, 1])
    return tuple(v1_2 for i in range(nd))
  elif nvals.shape == (2,):
    # (before, after)  (not in the numpy docstring but works anyway)
    v1_2 = as_scalar_dim(nvals[0]), as_scalar_dim(nvals[1])
    return tuple(v1_2 for i in range(nd))
  elif nvals.shape == (1,):
    # (pad,)
    v = as_scalar_dim(nvals[0])
    return tuple((v, v) for i in range(nd))
  elif nvals.shape == ():
    # pad
    v = as_scalar_dim(nvals.flat[0])
    return tuple((v, v) for i in range(nd))
  else:
    raise ValueError(f"jnp.pad: {name} with {nd=} has unsupported shape {nvals.shape}. "
                     f"Valid shapes are ({nd}, 2), (1, 2), (2,), (1,), or ().")
def _check_no_padding(axis_padding: Tuple[Any, Any], mode: str):
if (axis_padding[0] > 0 or axis_padding[1] > 0):
msg = "Cannot apply '{}' padding to empty axis"
raise ValueError(msg.format(mode))
def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array:
  """Pad each axis with constant values (one value per axis side)."""
  rank = ndim(array)
  # Broadcast to a (rank, 2) table of (before, after) values in the array's dtype.
  fill = broadcast_to(constant_values, (rank, 2))
  fill = lax_internal._convert_element_type(
      fill, array.dtype, dtypes.is_weakly_typed(array))
  for ax in range(rank):
    low_config = [(0, 0, 0)] * rank
    low_config[ax] = (pad_width[ax][0], 0, 0)
    array = lax.pad(array, fill[ax, 0], low_config)
    high_config = [(0, 0, 0)] * rank
    high_config[ax] = (0, pad_width[ax][1], 0)
    array = lax.pad(array, fill[ax, 1], high_config)
  return array
def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array:
  """Pad each axis with periodic copies of the array ("wrap" mode)."""
  for ax in range(ndim(array)):
    n = array.shape[ax]
    if n == 0:
      _check_no_padding(pad_width[ax], "wrap")
      continue
    repeats, (left_rem, right_rem) = np.divmod(pad_width[ax], n)
    # Whole copies cover most of the padding; partial slices cover the rest.
    pieces = []
    if left_rem:
      pieces.append(lax.slice_in_dim(array, n - left_rem, n, axis=ax))
    pieces.extend([array] * (repeats.sum() + 1))
    if right_rem:
      pieces.append(lax.slice_in_dim(array, 0, right_rem, axis=ax))
    array = lax.concatenate(pieces, dimension=ax)
  return array
def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
                              mode: str, reflect_type: str) -> Array:
  """Pad by mirroring the array at its edges ("symmetric" / "reflect" modes).

  "symmetric" repeats the edge element in the mirror; "reflect" excludes it
  (when the axis has more than one element). With ``reflect_type == "odd"``
  each mirrored value x is replaced by ``2 * edge - x``. Padding wider than the
  axis is built up in chunks of at most ``n - offset`` elements, each mirrored
  from the current, already partially padded, array.
  """
  assert mode in ("symmetric", "reflect")
  assert reflect_type in ("even", "odd")

  for i in range(ndim(array)):
    if array.shape[i] == 0:
      _check_no_padding(pad_width[i], mode)
      continue

    n = array.shape[i]
    # "reflect" skips the edge element itself (unless there is only one).
    offset = 1 if (mode == "reflect" and n > 1) else 0

    def build_padding(array, padding, before):
      if before:
        edge = lax.slice_in_dim(array, 0, 1, axis=i)
      else:
        edge = lax.slice_in_dim(array, -1, None, axis=i)

      while padding > 0:
        curr_pad = _min(padding, n - offset)
        padding -= curr_pad

        if before:
          start = offset
          stop = offset + curr_pad
        else:
          start = -(curr_pad + offset)
          stop = None if (mode == "symmetric" or n == 1) else -1

        x = lax.slice_in_dim(array, start, stop, axis=i)
        x = flip(x, axis=i)

        if reflect_type == 'odd':
          x = 2 * edge - x
          if n > 1:
            # The next chunk is reflected about the new outermost value.
            if before:
              edge = lax.slice_in_dim(x, 0, 1, axis=i)
            else:
              edge = lax.slice_in_dim(x, -1, None, axis=i)

        if before:
          array = lax.concatenate([x, array], dimension=i)
        else:
          array = lax.concatenate([array, x], dimension=i)
      return array

    array = build_padding(array, pad_width[i][0], before=True)
    array = build_padding(array, pad_width[i][1], before=False)
  return array
def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array:
  """Pad each axis by repeating its first and last slices ("edge" mode)."""
  for ax in range(ndim(array)):
    n = array.shape[ax]
    if n == 0:
      _check_no_padding(pad_width[ax], "edge")
      continue
    before, after = pad_width[ax]
    first = lax.slice_in_dim(array, 0, 1, axis=ax)
    last = lax.slice_in_dim(array, n - 1, n, axis=ax)
    array = lax.concatenate(
        [repeat(first, before, axis=ax), array, repeat(last, after, axis=ax)],
        dimension=ax)
  return array
def _pad_linear_ramp(array: Array, pad_width: PadValue[int],
                     end_values: PadValue[ArrayLike]) -> Array:
  """Pad each axis with linear ramps from ``end_values`` toward the edge values.

  The ramps stop one step short of the edge (``endpoint=False``). As with the
  other modes that read edge values, an empty axis may only be padded by zero.
  """
  for axis in range(ndim(array)):
    if array.shape[axis] == 0:
      # There is no edge value to ramp towards (the slices below would be
      # invalid); NumPy likewise only allows zero padding here.
      _check_no_padding(pad_width[axis], "linear_ramp")
      continue
    edge_before = lax.slice_in_dim(array, 0, 1, axis=axis)
    edge_after = lax.slice_in_dim(array, -1, None, axis=axis)
    ramp_before = linspace(
        start=end_values[axis][0],
        stop=edge_before.squeeze(axis),  # Dimension is replaced by linspace
        num=pad_width[axis][0],
        endpoint=False,
        dtype=array.dtype,
        axis=axis
    )
    ramp_before = lax_internal._convert_element_type(
        ramp_before, weak_type=dtypes.is_weakly_typed(array))
    ramp_after = linspace(
        start=end_values[axis][1],
        stop=edge_after.squeeze(axis),  # Dimension is replaced by linspace
        num=pad_width[axis][1],
        endpoint=False,
        dtype=array.dtype,
        axis=axis
    )
    ramp_after = lax_internal._convert_element_type(
        ramp_after, weak_type=dtypes.is_weakly_typed(array))

    # Reverse linear space in appropriate dimension
    ramp_after = flip(ramp_after, axis)

    array = lax.concatenate([ramp_before, array, ramp_after], dimension=axis)
  return array
def _pad_stats(array: Array, pad_width: PadValue[int], stat_length: Optional[PadValue[int]],
               stat_func: PadStatFunc) -> Array:
  """Pad each axis with a statistic (max/min/mean/median) of its values.

  With ``stat_length`` the statistic on each side is taken over at most that
  many elements nearest the edge; otherwise it is taken over the whole axis.
  Integer arrays have the statistic rounded before casting back. An empty axis
  may only be padded by zero, since there are no values to summarize.
  """
  nd = ndim(array)
  for i in range(nd):
    if array.shape[i] == 0:
      # Without this check, max/min of an empty slice raise even for zero
      # padding, and mean would silently pad with NaN.
      _check_no_padding(pad_width[i], getattr(stat_func, "__name__", "stat"))
      continue
    if stat_length is None:
      stat_before = stat_func(array, axis=i, keepdims=True)
      stat_after = stat_before
    else:
      array_length = array.shape[i]
      length_before, length_after = stat_length[i]
      if length_before == 0 or length_after == 0:
        raise ValueError("stat_length of 0 yields no value for padding")

      # Limit stat_length to length of array.
      length_before = _min(length_before, array_length)
      length_after = _min(length_after, array_length)

      slice_before = lax.slice_in_dim(array, 0, length_before, axis=i)
      slice_after = lax.slice_in_dim(array, -length_after, None, axis=i)
      stat_before = stat_func(slice_before, axis=i, keepdims=True)
      stat_after = stat_func(slice_after, axis=i, keepdims=True)

    if np.issubdtype(array.dtype, np.integer):
      stat_before = round(stat_before)
      stat_after = round(stat_after)

    stat_before = lax_internal._convert_element_type(
        stat_before, array.dtype, dtypes.is_weakly_typed(array))
    stat_after = lax_internal._convert_element_type(
        stat_after, array.dtype, dtypes.is_weakly_typed(array))

    npad_before, npad_after = pad_width[i]
    pad_before = repeat(stat_before, npad_before, axis=i)
    pad_after = repeat(stat_after, npad_after, axis=i)

    array = lax.concatenate([pad_before, array, pad_after], dimension=i)
  return array
def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array:
  """Pad with uninitialized values ("empty" mode)."""
  # Note: jax.numpy.empty = jax.numpy.zeros
  for ax in range(ndim(array)):
    before, after = pad_width[ax]
    leading, trailing = array.shape[:ax], array.shape[ax + 1:]
    pad_before = empty_like(array, shape=leading + (before,) + trailing)
    pad_after = empty_like(array, shape=leading + (after,) + trailing)
    array = lax.concatenate([pad_before, array, pad_after], dimension=ax)
  return array
def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any], **kwargs) -> Array:
  """Pad using a user-supplied function (NumPy's callable ``mode``)."""
  widths = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
  # Zero-pad first, then let `func` rewrite each 1-D lane along every axis.
  # It is called with NumPy's (vector, pad pair, axis, kwargs) convention and
  # must return the new vector.
  result = _pad_constant(array, widths, asarray(0))
  for ax in range(ndim(result)):
    result = apply_along_axis(func, ax, result, widths[ax], ax, kwargs)
  return result
@partial(jit, static_argnums=(1, 2, 4, 5, 6))
def _pad(array: ArrayLike, pad_width: PadValueLike[int],
         mode: Union[str, PadStatFunc],
         constant_values: ArrayLike, stat_length: PadValueLike[int],
         end_values: PadValueLike[ArrayLike], reflect_type: str):
  """Jitted dispatcher behind :func:`pad` for the string-named modes.

  ``pad`` has already validated ``mode`` and its keyword arguments. Every
  argument except ``array`` and ``constant_values`` is static.
  """
  array = asarray(array)
  nd = ndim(array)

  if nd == 0:
    # A 0-d array has no axes to pad.
    return array

  stat_funcs: Dict[str, PadStatFunc] = {
      "maximum": amax, "minimum": amin, "mean": mean, "median": median}

  pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
  pad_width_arr = np.array(pad_width)
  if pad_width_arr.shape != (nd, 2):
    raise ValueError(f"Expected pad_width to have shape {(nd, 2)}; got {pad_width_arr.shape}.")

  if np.any(pad_width_arr < 0):
    raise ValueError("index can't contain negative values")

  if mode == "constant":
    return _pad_constant(array, pad_width, asarray(constant_values))

  elif mode == "wrap":
    return _pad_wrap(array, pad_width)

  elif mode in ("symmetric", "reflect"):
    return _pad_symmetric_or_reflect(array, pad_width, str(mode), reflect_type)

  elif mode == "edge":
    return _pad_edge(array, pad_width)

  elif mode == "linear_ramp":
    end_values = _broadcast_to_pairs(end_values, nd, "end_values")
    return _pad_linear_ramp(array, pad_width, end_values)

  elif mode in stat_funcs:
    if stat_length is not None:
      stat_length = _broadcast_to_pairs(stat_length, nd, "stat_length")
    return _pad_stats(array, pad_width, stat_length, stat_funcs[str(mode)])

  elif mode == "empty":
    return _pad_empty(array, pad_width)

  else:
    # Unreachable: `pad` rejects unknown modes before calling here. Raise
    # explicitly instead of `assert False`, which `python -O` strips (the
    # function would then silently return None).
    raise AssertionError("Should not be reached since pad already handled unsupported and "
                         "not implemented modes")
@_wraps(np.pad, lax_description="""\
Unlike numpy, JAX "function" mode's argument (which is another function) should return
the modified array. This is because Jax arrays are immutable.
(In numpy, "function" mode's argument should modify a rank 1 array in-place.)
""")
def pad(array: ArrayLike, pad_width: PadValueLike[int],
        mode: Union[str, Callable[..., Any]] = "constant", **kwargs) -> Array:
  _check_arraylike("pad", array)
  pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
  if pad_width and not _all(core.is_dim(before) and core.is_dim(after)
                            for before, after in pad_width):
    raise TypeError('`pad_width` must be of integral type.')

  if callable(mode):
    return _pad_func(asarray(array), pad_width, mode, **kwargs)

  # Keyword arguments each string mode accepts.
  allowed_kwargs = {
      'empty': [], 'edge': [], 'wrap': [],
      'constant': ['constant_values'],
      'linear_ramp': ['end_values'],
      'maximum': ['stat_length'],
      'mean': ['stat_length'],
      'median': ['stat_length'],
      'minimum': ['stat_length'],
      'reflect': ['reflect_type'],
      'symmetric': ['reflect_type'],
  }
  accepted = allowed_kwargs.get(mode)  # type: ignore[call-overload]
  if accepted is None:
    msg = "Unimplemented padding mode '{}' for np.pad."
    raise NotImplementedError(msg.format(mode))
  unsupported_kwargs = set(kwargs) - set(accepted)
  if unsupported_kwargs:
    raise ValueError("unsupported keyword arguments for mode '{}': {}"
                     .format(mode, unsupported_kwargs))
  # Fall back to NumPy's defaults for anything not given.
  return _pad(array, pad_width, mode,
              kwargs.get('constant_values', 0),
              kwargs.get('stat_length', None),
              kwargs.get('end_values', 0),
              kwargs.get('reflect_type', "even"))
### Array-creation functions

@_wraps(np.stack, skip_params=['out'])
def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
          axis: int = 0, out: None = None, dtype: Optional[DTypeLike] = None) -> Array:
  if not len(arrays):
    raise ValueError("Need at least one array to stack.")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
  if isinstance(arrays, (np.ndarray, Array)):
    # A single array is treated as a stack along its leading dimension: add
    # a unit axis and let concatenate's array fast path merge it.
    axis = _canonicalize_axis(axis, arrays.ndim)
    return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
  _stackable(*arrays) or _check_arraylike("stack", *arrays)
  reference_shape = shape(arrays[0])
  axis = _canonicalize_axis(axis, len(reference_shape) + 1)
  if _any(shape(arr) != reference_shape for arr in arrays):
    raise ValueError("All input arrays must have the same shape.")
  expanded = [expand_dims(arr, axis) for arr in arrays]
  return concatenate(expanded, axis=axis, dtype=dtype)
@_wraps(np.tile)
def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array:
  _stackable(A) or _check_arraylike("tile", A)
  try:
    iter(reps)  # type: ignore[arg-type]
  except TypeError:
    reps_tup: Tuple[DimSize, ...] = (reps,)
  else:
    reps_tup = tuple(reps)  # type: ignore[assignment,arg-type]
  reps_tup = tuple(rep if not core.is_constant_dim(rep) else operator.index(rep)
                   for rep in reps_tup)
  # Left-pad whichever of A's shape and reps is shorter with ones so they align.
  A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A)
  reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
  # Put a unit axis before every dimension of A, broadcast those unit axes to
  # the repeat counts, then merge each (rep, dim) pair into one dimension.
  interleaved_shape = [d for dim in A_shape for d in (1, dim)]
  broadcast_shape = [d for pair in zip(reps_tup, A_shape) for d in pair]
  tiled = broadcast_to(reshape(A, interleaved_shape), broadcast_shape)
  return reshape(tiled, tuple(np.multiply(A_shape, reps_tup)))
def _concatenate_array(arr: ArrayLike, axis: Optional[int],
                       dtype: Optional[DTypeLike] = None) -> Array:
  """Concatenate the sub-arrays ``arr[0], arr[1], ...`` along ``axis``.

  ``axis`` refers to the dimensions of the sub-arrays; ``None`` flattens the
  whole input.
  """
  # Fast path for concatenation when the input is an ndarray rather than a list.
  arr = asarray(arr, dtype=dtype)
  if arr.ndim == 0 or arr.shape[0] == 0:
    raise ValueError("Need at least one array to concatenate.")
  if axis is None:
    return lax.reshape(arr, (arr.size,))
  if arr.ndim == 1:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  axis = _canonicalize_axis(axis, arr.ndim - 1)
  # One reshape does it all: `dimensions` first transposes the leading
  # (stacking) dimension to sit just before sub-array dim `axis`, then the
  # reshape merges the two. The stacking index varies slowest, so the
  # sub-arrays appear in order along `axis`.
  shape = arr.shape[1:axis + 1] + (arr.shape[0] * arr.shape[axis + 1],) + arr.shape[axis + 2:]
  dimensions = [*range(1, axis + 1), 0, *range(axis + 1, arr.ndim)]
  return lax.reshape(arr, shape, dimensions)
@_wraps(np.concatenate)
def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
                axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
  if isinstance(arrays, (np.ndarray, Array)):
    return _concatenate_array(arrays, axis, dtype=dtype)
  _stackable(*arrays) or _check_arraylike("concatenate", *arrays)
  if not len(arrays):
    raise ValueError("Need at least one array to concatenate.")
  if ndim(arrays[0]) == 0:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  if axis is None:
    return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
  if hasattr(arrays[0], "concatenate"):
    # Objects that provide their own concatenate handle the whole call.
    return arrays[0].concatenate(arrays[1:], axis, dtype=dtype)  # type: ignore[union-attr]
  axis = _canonicalize_axis(axis, ndim(arrays[0]))
  operands = (_promote_dtypes(*arrays) if dtype is None
              else [asarray(arr, dtype=dtype) for arr in arrays])
  # lax.concatenate can be slow to compile for wide concatenations, so form a
  # tree of concatenations as a workaround especially for op-by-op mode.
  # (https://github.com/google/jax/issues/653).
  fan_in = 16
  while len(operands) > 1:
    operands = [lax.concatenate(operands[start:start + fan_in], axis)
                for start in range(0, len(operands), fan_in)]
  return operands[0]
@_wraps(np.vstack)
def vstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
           dtype: Optional[DTypeLike] = None) -> Array:
  # Promote every input to >= 2-D and join along axis 0. An array input is
  # treated as a stack of its leading-axis slices, hence the vmap.
  is_single_array = isinstance(tup, (np.ndarray, Array))
  promoted = (jax.vmap(atleast_2d)(tup) if is_single_array
              else [atleast_2d(item) for item in tup])
  return concatenate(promoted, axis=0, dtype=dtype)
# `row_stack` is an alias of `vstack`, mirroring NumPy's API.
row_stack = vstack
@_wraps(np.hstack)
def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
           dtype: Optional[DTypeLike] = None) -> Array:
  # Promote inputs to >= 1-D. If the first input is 1-D, join along axis 0;
  # otherwise join along axis 1.
  if isinstance(tup, (np.ndarray, Array)):
    promoted = jax.vmap(atleast_1d)(tup)
    first_ndim = promoted.ndim - 1  # exclude the mapped (stacking) axis
  else:
    promoted = [atleast_1d(item) for item in tup]
    first_ndim = promoted[0].ndim
  join_axis = 0 if first_ndim == 1 else 1
  return concatenate(promoted, axis=join_axis, dtype=dtype)
@_wraps(np.dstack)
def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
           dtype: Optional[DTypeLike] = None) -> Array:
  # Promote every input to >= 3-D and join along the third axis.
  if isinstance(tup, (np.ndarray, Array)):
    # Treat the array as a stack of its leading-axis slices.
    pieces = jax.vmap(atleast_3d)(tup)
  else:
    pieces = list(map(atleast_3d, tup))
  return concatenate(pieces, axis=2, dtype=dtype)
@_wraps(np.column_stack)
def column_stack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]]) -> Array:
  # Inputs with fewer than two dims become columns via atleast_2d(x).T;
  # inputs already >= 2-D are used unchanged. Everything is joined on axis 1.
  if isinstance(tup, (np.ndarray, Array)):
    if tup.ndim < 3:
      columns = jax.vmap(lambda x: atleast_2d(x).T)(tup)
    else:
      columns = tup
  else:
    columns = []
    for item in map(asarray, tup):
      columns.append(item if item.ndim >= 2 else atleast_2d(item).T)
  return concatenate(columns, 1)
@_wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
           out: None = None, mode: str = 'raise') -> Array:
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
  _check_arraylike('choose', a, *choices)
  if not issubdtype(_dtype(a), integer):
    raise ValueError("`a` array must be integer typed")
  n_choices = len(choices)
  # Bring the index array into [0, n_choices) according to `mode`.
  if mode == 'wrap':
    index_arr: Array = asarray(a) % n_choices
  elif mode == 'clip':
    index_arr = clip(a, 0, n_choices - 1)
  elif mode == 'raise':
    # Range checking needs concrete values, so this mode cannot be traced.
    index_arr = core.concrete_or_error(asarray, a,
      "The error occurred because jnp.choose was jit-compiled"
      " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
    out_of_range = (index_arr < 0) | (index_arr >= n_choices)
    if any(out_of_range):
      raise ValueError("invalid entry in choice array")
  else:
    raise ValueError(f"mode={mode!r} not understood. Must be 'raise', 'wrap', or 'clip'")
  index_arr, *choices = broadcast_arrays(index_arr, *choices)
  # Select choices[index_arr[idx]][idx] for every position idx via a sparse grid.
  return array(choices)[(index_arr,) + indices(index_arr.shape, sparse=True)]
def _atleast_nd(x: ArrayLike, n: int) -> Array:
  """Return ``x`` as an array with leading unit axes added up to rank ``n``."""
  missing = n - ndim(x)
  if missing > 0:
    return lax.broadcast(x, (1,) * missing)
  return asarray(x)
def _block(xs: Union[ArrayLike, List[ArrayLike]]) -> Tuple[Array, int]:
  """Recursively assemble a nested-list block layout.

  Returns the assembled array together with the nesting depth of ``xs``
  (1 for a leaf). Lists at nesting depth d are concatenated along axis -d.
  """
  if isinstance(xs, tuple):
    raise ValueError("jax.numpy.block does not allow tuples, got {}"
                     .format(xs))
  if not isinstance(xs, list):
    return asarray(xs), 1
  if len(xs) == 0:
    raise ValueError("jax.numpy.block does not allow empty list arguments")
  parts, depths = unzip2([_block(x) for x in xs])
  depth = depths[0]
  if _any(d != depth for d in depths[1:]):
    raise ValueError("Mismatched list depths in jax.numpy.block")
  # Promote all parts to a common rank that is at least the nesting depth.
  rank = _max(depth, _max(ndim(p) for p in parts))
  parts = tuple(_atleast_nd(p, rank) for p in parts)
  return concatenate(parts, axis=-depth), depth + 1
@_wraps(np.block)
@jit
def block(arrays: Union[ArrayLike, List[ArrayLike]]) -> Array:
  # _block returns (assembled_array, nesting_depth); only the array is needed.
  return _block(arrays)[0]
@_wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]:
  # With anything other than exactly one argument, return a list of results.
  if len(arys) != 1:
    return [atleast_1d(item) for item in arys]
  result = asarray(arys[0])
  if ndim(result) < 1:
    result = reshape(result, -1)  # scalar -> shape (1,)
  return result
@_wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]:
  # With anything other than exactly one argument, return a list of results.
  if len(arys) != 1:
    return [atleast_2d(item) for item in arys]
  result = asarray(arys[0])
  rank = ndim(result)
  if rank >= 2:
    return result
  # scalar -> (1, 1); 1-D of length n -> (1, n)
  return expand_dims(result, axis=(0, 1) if rank == 0 else 0)
@_wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_3d(*arys: ArrayLike) -> Union[Array, List[Array]]:
  # With anything other than exactly one argument, return a list of results.
  if len(arys) != 1:
    return [atleast_3d(item) for item in arys]
  result = asarray(arys[0])
  # Axes to insert by input rank: () -> (1,1,1); (n,) -> (1,n,1); (m,n) -> (m,n,1).
  new_axes = {0: (0, 1, 2), 1: (0, 2), 2: 2}.get(ndim(result))
  if new_axes is not None:
    result = expand_dims(result, axis=new_axes)
  return result
_ARRAY_DOC = """
This function will create arrays on JAX's default device. For control of the
device placement of data, see :func:`jax.device_put`. More information is
available in the JAX FAQ at :ref:`faq-data-placement` (full FAQ at
https://jax.readthedocs.io/en/latest/faq.html).
"""
# Build a JAX array from `object`, which may be a Python scalar, a (nested)
# list/tuple, a NumPy array, a JAX array or tracer, an object defining
# `__jax_array__`, or any object supporting the buffer protocol.
@_wraps(np.array, lax_description=_ARRAY_DOC)
def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
          order: Optional[str] = "K", ndmin: int = 0) -> Array:
  # Only the default memory order is supported.
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")

  # check if the given dtype is compatible with JAX
  lax_internal._check_user_dtype_supported(dtype, "array")

  # Here we make a judgment call: we only return a weakly-typed array when the
  # input object itself is weakly typed. That ensures asarray(x) is a no-op
  # whenever x is weak, but avoids introducing weak types with something like
  # array([1, 2, 3])
  weak_type = dtype is None and dtypes.is_weakly_typed(object)

  # For Python scalar literals, call coerce_to_array to catch any overflow
  # errors. We don't use dtypes.is_python_scalar because we don't want this
  # triggering for traced values. We do this here because it matters whether or
  # not dtype is None. We don't assign the result because we want the raw object
  # to be used for type inference below.
  if isinstance(object, (bool, int, float, complex)):
    _ = dtypes.coerce_to_array(object, dtype)

  # Replace every leaf that defines `__jax_array__` with the array it returns.
  object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf,
                    object)
  leaves = tree_leaves(object)
  if dtype is None:
    # Use lattice_result_type rather than result_type to avoid canonicalization.
    # Otherwise, weakly-typed inputs would have their dtypes canonicalized.
    try:
      dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
    except TypeError:
      # This happens if, e.g. one of the entries is a memoryview object.
      # This is rare, so we only handle it if the normal path fails.
      leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
      dtype = dtypes._lattice_result_type(*leaves)[0]

  if not weak_type:
    dtype = dtypes.canonicalize_dtype(dtype)

  # We can't use the ndarray class because we need to handle internal buffers
  # (See https://github.com/google/jax/issues/8950)
  ndarray_types = (device_array.DeviceArray, core.Tracer, ArrayImpl)

  out: ArrayLike

  if _all(not isinstance(leaf, ndarray_types) for leaf in leaves):
    # No JAX arrays or tracers among the leaves: let NumPy build the array.
    # TODO(jakevdp): falling back to numpy here fails to overflow for lists
    # containing large integers; see discussion in
    # https://github.com/google/jax/pull/6047. More correct would be to call
    # coerce_to_array on each leaf, but this may have performance implications.
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
  elif isinstance(object, ndarray_types):
    # A single JAX array or tracer: copy it only when copy=True.
    assert object.aval is not None
    out = _array_copy(object) if copy else object
  elif isinstance(object, (list, tuple)):
    # A list/tuple containing at least one JAX array or tracer: convert each
    # element separately and stack the results.
    if object:
      out = stack([asarray(elt, dtype=dtype) for elt in object])
    else:
      out = np.array([], dtype=dtype)
  else:
    # Last resort: objects exposing the buffer protocol go through NumPy.
    try:
      view = memoryview(object)
    except TypeError:
      pass  # `object` does not support the buffer interface.
    else:
      return array(np.asarray(view), dtype, copy, ndmin=ndmin)

    raise TypeError(f"Unexpected input type for array: {type(object)}")

  # Apply the final dtype (and weak type), then prepend unit axes up to ndmin.
  out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
  if ndmin > ndim(out_array):
    out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
  return out_array
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
  """Return ``x`` unchanged if ``dtypes.dtype`` accepts it, else ``np.asarray(x)``.

  Used by ``array`` to recover from leaves (e.g. memoryview objects) whose
  dtype cannot be inferred directly.
  """
  try:
    dtypes.dtype(x)
  except TypeError:
    return np.asarray(x)
  return x
@_wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array:
  # Like `array`, but without forcing a copy; an explicit dtype is canonicalized first.
  lax_internal._check_user_dtype_supported(dtype, "asarray")
  if dtype is not None:
    dtype = dtypes.canonicalize_dtype(dtype)
  return array(a, dtype=dtype, copy=False, order=order)  # type: ignore
@_wraps(np.copy, lax_description=_ARRAY_DOC)
def copy(a: ArrayLike, order: Optional[str] = None) -> Array:
  # Validate the input, then delegate to `array` with copying requested.
  _check_arraylike("copy", a)
  return array(a, order=order, copy=True)
@_wraps(np.zeros_like)
def zeros_like(a: ArrayLike, dtype: Optional[DTypeLike] = None,
               shape: Any = None) -> Array:
  _check_arraylike("zeros_like", a)
  lax_internal._check_user_dtype_supported(dtype, "zeros_like")
  # An explicit shape overrides a's shape; normalize it first.
  target_shape = None if shape is None else canonicalize_shape(shape)
  return lax.full_like(a, 0, dtype, target_shape)
@_wraps(np.ones_like)
def ones_like(a: ArrayLike, dtype: Optional[DTypeLike] = None,
              shape: Any = None) -> Array:
  _check_arraylike("ones_like", a)
  lax_internal._check_user_dtype_supported(dtype, "ones_like")
  # An explicit shape overrides a's shape; normalize it first.
  target_shape = None if shape is None else canonicalize_shape(shape)
  return lax.full_like(a, 1, dtype, target_shape)
@_wraps(np.empty_like, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty_like(prototype: ArrayLike, dtype: Optional[DTypeLike] = None,
               shape: Any = None) -> Array:
  _check_arraylike("empty_like", prototype)
  lax_internal._check_user_dtype_supported(dtype, "empty_like")
  # There is no uninitialized allocation, so return a zero-filled array.
  return zeros_like(prototype, shape=shape, dtype=dtype)
@_wraps(np.full)
def full(shape: Any, fill_value: ArrayLike,
         dtype: Optional[DTypeLike] = None) -> Array:
  lax_internal._check_user_dtype_supported(dtype, "full")
  _check_arraylike("full", fill_value)
  # Non-scalar fill values are broadcast to the requested shape.
  if ndim(fill_value) != 0:
    return broadcast_to(asarray(fill_value, dtype=dtype), shape)
  return lax.full(canonicalize_shape(shape), fill_value, dtype)
@_wraps(np.full_like)
def full_like(a: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
              shape: Any = None) -> Array:
  lax_internal._check_user_dtype_supported(dtype, "full_like")
  _check_arraylike("full_like", a, fill_value)
  if shape is not None:
    shape = canonicalize_shape(shape)
  if ndim(fill_value) == 0:
    return lax.full_like(a, fill_value, dtype, shape)
  # Non-scalar fill: broadcast it, inheriting a's shape/dtype unless overridden.
  target_shape = np.shape(a) if shape is None else shape
  target_dtype = result_type(a) if dtype is None else dtype
  return broadcast_to(asarray(fill_value, dtype=target_dtype), target_shape)
@_wraps(np.zeros)
def zeros(shape: Any, dtype: Optional[DTypeLike] = None) -> Array:
  # Generator objects are rejected up front.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  lax_internal._check_user_dtype_supported(dtype, "zeros")
  return lax.full(canonicalize_shape(shape), 0, _jnp_dtype(dtype))
@_wraps(np.ones)
def ones(shape: Any, dtype: Optional[DTypeLike] = None) -> Array:
  # Generator objects are rejected up front.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  dims = canonicalize_shape(shape)
  lax_internal._check_user_dtype_supported(dtype, "ones")
  return lax.full(dims, 1, _jnp_dtype(dtype))
@_wraps(np.empty, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty(shape: Any, dtype: Optional[DTypeLike] = None) -> Array:
  lax_internal._check_user_dtype_supported(dtype, "empty")
  # No uninitialized allocation is available, so this is simply `zeros`.
  return zeros(shape, dtype=dtype)
@_wraps(np.array_equal)
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
  # Inputs that cannot be converted, or whose shapes differ, are unequal.
  try:
    a1, a2 = asarray(a1), asarray(a2)
  except Exception:
    return bool_(False)
  if shape(a1) != shape(a2):
    return bool_(False)
  matches = asarray(a1 == a2)
  if equal_nan:
    # Positions where both inputs are NaN also count as equal.
    both_nan = logical_and(isnan(a1), isnan(a2))
    matches = logical_or(matches, both_nan)
  return all(matches)
@_wraps(np.array_equiv)
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
  # Like array_equal, but shapes only need to be broadcast-compatible.
  try:
    a1, a2 = asarray(a1), asarray(a2)
  except Exception:
    return bool_(False)
  try:
    elementwise = equal(a1, a2)
  except ValueError:
    return bool_(False)  # the shapes do not broadcast together
  return all(elementwise)
# The np.from* style constructors below mostly delegate to NumPy and convert
# the result to a JAX array.
@_wraps(np.frombuffer)
def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = float,
               count: int = -1, offset: int = 0) -> Array:
  host_array = np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)
  return asarray(host_array)
def fromfile(*args, **kwargs):
  """Unimplemented JAX wrapper for jnp.fromfile.

  This function is left deliberately unimplemented because it may be non-pure and thus
  unsafe for use with JIT and other JAX transformations. Consider using
  ``jnp.asarray(np.fromfile(...))`` instead, although care should be taken if ``np.fromfile``
  is used within jax transformations because of its potential side-effect of consuming the
  file object; for more information see `Common Gotchas: Pure Functions
  <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions>`_.

  Raises:
    NotImplementedError: always, whatever the arguments.
  """
  raise NotImplementedError(
    "jnp.fromfile() is not implemented because it may be non-pure and thus unsafe for use "
    "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) "
    "instead, although care should be taken if np.fromfile is used within a JAX transformation "
    "because of its potential side-effect of consuming the file object; for more information see "
    "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")
def fromiter(*args, **kwargs):
  """Unimplemented JAX wrapper for jnp.fromiter.

  This function is left deliberately unimplemented because it may be non-pure and thus
  unsafe for use with JIT and other JAX transformations. Consider using
  ``jnp.asarray(np.fromiter(...))`` instead, although care should be taken if ``np.fromiter``
  is used within jax transformations because of its potential side-effect of consuming the
  iterable object; for more information see `Common Gotchas: Pure Functions
  <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions>`_.

  Raises:
    NotImplementedError: always, whatever the arguments.
  """
  raise NotImplementedError(
    "jnp.fromiter() is not implemented because it may be non-pure and thus unsafe for use "
    "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) "
    "instead, although care should be taken if np.fromiter is used within a JAX transformation "
    "because of its potential side-effect of consuming the iterable object; for more information see "
    "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")
@_wraps(getattr(np, "from_dlpack", None))
def from_dlpack(x: Any) -> Array:
  # Imported at call time rather than module load; presumably to avoid an
  # import cycle -- TODO confirm.
  from jax.dlpack import from_dlpack as _jax_from_dlpack  # pylint: disable=g-import-not-at-top
  capsule = x.__dlpack__()
  return _jax_from_dlpack(capsule)
@_wraps(np.fromfunction)
def fromfunction(function: Callable[..., Array], shape: Any,
                 *, dtype: DTypeLike = float, **kwargs) -> Array:
  shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
  rank = len(shape)
  # Wrap `function` in one vmap per dimension. Each layer maps exactly one
  # positional argument over axis 0 and leaves the rest unmapped, so the final
  # call evaluates `function` over every combination of coordinates.
  for dim in range(rank):
    in_axes = tuple(0 if j == dim else None for j in reversed(range(rank)))
    function = jax.vmap(function, in_axes=in_axes)
  coords = (arange(size, dtype=dtype) for size in shape)
  return function(*coords, **kwargs)
@_wraps(np.fromstring)
def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array:
  # Parse with NumPy, then convert the result to a JAX array.
  parsed = np.fromstring(string=string, dtype=dtype, count=count, sep=sep)
  return asarray(parsed)
@_wraps(np.eye)
def eye(N: DimSize, M: Optional[DimSize] = None, k: int = 0,
        dtype: Optional[DTypeLike] = None) -> Array:
  # Build an N x M matrix (M defaults to N) with the diagonal at offset `k`,
  # delegating to lax_internal._eye.
  lax_internal._check_user_dtype_supported(dtype, "eye")
  N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
  M_int = N_int if M is None else core.canonicalize_dim(M, "'M' argument of jnp.eye()")
  if N_int < 0 or M_int < 0:
    # Report the resolved sizes. When M is omitted it equals N, and printing
    # the raw `M` argument would show "None" instead of the offending value.
    raise ValueError(f"negative dimensions are not allowed, got {N_int} and {M_int}")
  k = operator.index(k)
  return lax_internal._eye(_jnp_dtype(dtype), (N_int, M_int), k)
@_wraps(np.identity)
def identity(n: DimSize, dtype: Optional[DTypeLike] = None) -> Array:
  lax_internal._check_user_dtype_supported(dtype, "identity")
  # The identity is the square `eye` matrix with no diagonal offset.
  return eye(n, M=None, k=0, dtype=dtype)
@_wraps(np.arange)
def arange(start: DimSize, stop: Optional[DimSize] = None,
           step: Optional[DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array:
  lax_internal._check_user_dtype_supported(dtype, "arange")
  # Apart from the special cases below, arguments must be concrete (non-traced).
  require = partial(core.concrete_or_error, None)
  msg = "It arose in jax.numpy.arange argument `{}`.".format
  if _any(core.is_special_dim_size(d) for d in (start, stop, step)):
    # Non-constant sizes are supported only in the single-argument form.
    if stop is not None or step is not None:
      raise ValueError(
          "jax.numpy.arange supports non-constant arguments only in "
          "single-argument form. Found "
          f"jax.numpy.arange({start=}, {stop=}, {step=})")
    # Compare against None explicitly. A NumPy dtype object can be falsy
    # (len() is its number of fields, 0 for plain dtypes), so `dtype or int_`
    # would silently discard e.g. dtype=np.dtype('int32').
    return lax.iota(int_ if dtype is None else dtype, start)
  if dtype is None:
    dtype = result_type(start, *(x for x in [stop, step] if x is not None))
  dtype = _jnp_dtype(dtype)
  if stop is None and step is None:
    # Single-argument form: `start` is actually the exclusive upper bound.
    start_dtype = _dtype(start)
    if not jax.config.jax_dynamic_shapes:
      start = require(start, msg("stop"))
    if (not dtypes.issubdtype(start_dtype, np.integer) and
        not core.is_opaque_dtype(start_dtype)):
      # Non-integer bounds are rounded up to give the element count.
      ceil_ = ceil if isinstance(start, core.Tracer) else np.ceil
      start = ceil_(start).astype(int)  # type: ignore
    return lax.iota(dtype, start)
  else:
    start = require(start, msg("start"))
    stop = None if stop is None else require(stop, msg("stop"))
    step = None if step is None else require(step, msg("step"))
    if step is None and start == 0 and stop is not None:
      # arange(0, stop) is an iota of length ceil(stop).
      stop = np.ceil(stop).astype(int)
      return lax.iota(dtype, stop)
    # General case: compute with NumPy and convert.
    return array(np.arange(start, stop=stop, step=step, dtype=dtype))
# Typing overloads for `linspace`: the return type depends on `retstep`
# (a single Array when False, an (Array, Array) tuple when True).
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, retstep: Literal[False] = False,
             dtype: Optional[DTypeLike] = None,
             axis: int = 0) -> Array: ...
# retstep=True passed positionally.
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
             endpoint: bool, retstep: Literal[True],
             dtype: Optional[DTypeLike] = None,
             axis: int = 0) -> Tuple[Array, Array]: ...
# retstep=True passed by keyword.
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, *, retstep: Literal[True],
             dtype: Optional[DTypeLike] = None,
             axis: int = 0) -> Tuple[Array, Array]: ...
# retstep not known statically.
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, retstep: bool = False,
             dtype: Optional[DTypeLike] = None,
             axis: int = 0) -> Union[Array, Tuple[Array, Array]]: ...
@_wraps(np.linspace)
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, retstep: bool = False,
             dtype: Optional[DTypeLike] = None,
             axis: int = 0) -> Union[Array, Tuple[Array, Array]]:
  # `num` and `axis` determine the output shape, so they are resolved to
  # concrete ints here and passed as static arguments to the jitted helper.
  n_samples = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
  out_axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
  return _linspace(start, stop, n_samples, endpoint, retstep, dtype, out_axis)
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, retstep: bool = False,
              dtype: Optional[DTypeLike] = None,
              axis: int = 0) -> Union[Array, Tuple[Array, Array]]:
  """Implementation of linspace differentiable in start and stop args.

  ``start`` and ``stop`` are broadcast against each other, and the ``num``
  samples are laid out along a new output axis ``axis``. Arithmetic is done in
  an inexact computation dtype; the result is cast to ``dtype`` at the end
  (after flooring when ``dtype`` is an integer type). When ``retstep`` is True
  the step ``delta`` is also returned.
  """
  lax_internal._check_user_dtype_supported(dtype, "linspace")
  if num < 0:
    raise ValueError(f"Number of samples, {num}, must be non-negative.")
  _check_arraylike("linspace", start, stop)

  if dtype is None:
    dtype = dtypes.to_inexact_dtype(result_type(start, stop))
  dtype = _jnp_dtype(dtype)
  computation_dtype = dtypes.to_inexact_dtype(dtype)
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)

  bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
  broadcast_start = broadcast_to(start, bounds_shape)
  broadcast_stop = broadcast_to(stop, bounds_shape)
  # The output has one more dim than the bounds, so a negative axis is
  # normalized against len(bounds_shape) + 1.
  axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
  bounds_shape.insert(axis, 1)
  # Number of intervals; with endpoint=True, `stop` is appended separately below.
  div = (num - 1) if endpoint else num
  if num > 1:
    delta: Array = lax.convert_element_type(stop - start, computation_dtype) / div
    iota_shape = [1,] * len(bounds_shape)
    iota_shape[axis] = div
    # This approach recovers the endpoints with float32 arithmetic,
    # but can lead to rounding errors for integer outputs.
    real_dtype = finfo(computation_dtype).dtype
    # step holds the fractions i / div for i in [0, div), laid out along `axis`.
    step = reshape(lax.iota(real_dtype, div), iota_shape) / div
    step = step.astype(computation_dtype)
    # Interpolate start * (1 - t) + stop * t.
    out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
      reshape(broadcast_stop, bounds_shape) * step)

    if endpoint:
      # Append `stop` itself as the final sample.
      out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))],
                            _canonicalize_axis(axis, out.ndim))

  elif num == 1:
    delta = asarray(nan if endpoint else stop - start, dtype=computation_dtype)
    out = reshape(broadcast_start, bounds_shape)
  else:  # num == 0 degenerate case, match numpy behavior
    empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
    empty_shape.insert(axis, 0)
    delta = asarray(nan, dtype=computation_dtype)
    out = reshape(array([], dtype=dtype), empty_shape)

  # For integer output dtypes, floor before the final cast.
  if issubdtype(dtype, integer) and not issubdtype(out.dtype, integer):
    out = lax.floor(out)

  if retstep:
    return lax.convert_element_type(out, dtype), delta
  else:
    return lax.convert_element_type(out, dtype)
@_wraps(np.logspace)
def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, base: ArrayLike = 10.0,
             dtype: Optional[DTypeLike] = None, axis: int = 0) -> Array:
  # Shape-determining arguments are resolved to concrete ints before being
  # passed as static arguments to the jitted implementation.
  n_samples = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace")
  out_axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace")
  return _logspace(start, stop, n_samples, endpoint, base, dtype, out_axis)
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, base: ArrayLike = 10.0,
              dtype: Optional[DTypeLike] = None, axis: int = 0) -> Array:
  """Implementation of logspace differentiable in start and stop args.

  Computes ``base ** linspace(start, stop, num, ...)`` in an inexact dtype and
  casts the result to ``dtype``.
  """
  lax_internal._check_user_dtype_supported(dtype, "logspace")
  out_dtype = _jnp_dtype(dtypes.to_inexact_dtype(result_type(start, stop))
                         if dtype is None else dtype)
  computation_dtype = dtypes.to_inexact_dtype(out_dtype)
  _check_arraylike("logspace", start, stop)
  lo = asarray(start, dtype=computation_dtype)
  hi = asarray(stop, dtype=computation_dtype)
  exponents = linspace(lo, hi, num,
                       endpoint=endpoint, retstep=False, dtype=None, axis=axis)
  return lax.convert_element_type(power(base, exponents), out_dtype)
@_wraps(np.geomspace)
def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True,
              dtype: Optional[DTypeLike] = None, axis: int = 0) -> Array:
  # Shape-determining arguments are resolved to concrete ints before being
  # passed as static arguments to the jitted implementation.
  n_samples = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace")
  out_axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace")
  return _geomspace(start, stop, n_samples, endpoint, dtype, out_axis)
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True,
               dtype: Optional[DTypeLike] = None, axis: int = 0) -> Array:
  """Implementation of geomspace differentiable in start and stop args.

  Samples are computed as ``signflip * logspace(log10(signflip * start),
  log10(signflip * stop))`` along axis 0, moved to ``axis``, and cast to
  ``dtype``.
  """
  lax_internal._check_user_dtype_supported(dtype, "geomspace")
  if dtype is None:
    dtype = dtypes.to_inexact_dtype(result_type(start, stop))
  dtype = _jnp_dtype(dtype)
  computation_dtype = dtypes.to_inexact_dtype(dtype)
  _check_arraylike("geomspace", start, stop)
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints
  # With s1 = sign(real(start)) and s2 = sign(real(stop)), this evaluates to
  # -1 when both are negative, 1 when neither is negative or the signs are
  # mixed, and 0 if one is negative and the other is exactly zero.
  signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2
  signflip = signflip.astype(computation_dtype)
  res = signflip * logspace(log10(signflip * start),
                            log10(signflip * stop), num,
                            endpoint=endpoint, base=10.0,
                            dtype=computation_dtype, axis=0)
  # The samples were generated along axis 0; move them to the requested axis.
  if axis != 0:
    res = moveaxis(res, 0, axis)
  return lax.convert_element_type(res, dtype)
@_wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
             indexing: str = 'xy') -> List[Array]:
  _check_arraylike("meshgrid", *xi)
  args = [asarray(x) for x in xi]
  if not copy:
    raise ValueError("jax.numpy.meshgrid only supports copy=True")
  if indexing not in ["xy", "ij"]:
    raise ValueError(f"Valid values for indexing are 'xy' and 'ij', got {indexing}")
  if _any(a.ndim != 1 for a in args):
    raise ValueError("Arguments to jax.numpy.meshgrid must be 1D, got shapes "
                     f"{[a.shape for a in args]}")
  # Cartesian ('xy') indexing is 'ij' indexing with the first two axes
  # exchanged: swap the inputs here and swap the outputs back at the end.
  swap_xy = indexing == "xy" and len(args) >= 2
  if swap_xy:
    args[0], args[1] = args[1], args[0]
  full_shape = [1 if sparse else a.shape[0] for a in args]

  def grid_shape(i, a):
    # Dense grids span every size; sparse grids are full length only on axis i.
    if not sparse:
      return full_shape
    return [*full_shape[:i], a.shape[0], *full_shape[i + 1:]]

  output = [lax.broadcast_in_dim(a, grid_shape(i, a), (i,))
            for i, a in enumerate(args)]
  if swap_xy:
    output[0], output[1] = output[1], output[0]
  return output
@_wraps(np.i0)
@jit
def i0(x: ArrayLike) -> Array:
  # Promote to an inexact dtype; only real floating inputs are accepted.
  promoted, = _promote_args_inexact("i0", x)
  if not issubdtype(promoted.dtype, np.floating):
    raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}")
  # i0 is even and i0e(x) == exp(-|x|) * i0(x), hence i0(x) == exp(|x|) * i0e(|x|).
  magnitude = lax.abs(promoted)
  return lax.mul(lax.exp(magnitude), lax.bessel_i0e(magnitude))
@_wraps(np.ix_)
def ix_(*args: ArrayLike) -> Tuple[Array, ...]:
  _check_arraylike("ix", *args)
  n = len(args)
  grids = []
  for axis, raw in enumerate(args):
    vec = asarray(raw)
    if len(vec.shape) != 1:
      raise ValueError(
          "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}".format(vec.shape))
    if _dtype(vec) == bool_:
      raise NotImplementedError(
          "Boolean arguments to jax.numpy.ix_ are not implemented")
    # Output `axis` has size 1 everywhere except along its own axis.
    out_shape = [1] * n
    out_shape[axis] = vec.shape[0]
    # Numpy uses an integer index type for empty arrays.
    grids.append(lax.full(out_shape, np.zeros((), np.intp)) if vec.size == 0
                 else lax.broadcast_in_dim(vec, out_shape, (axis,)))
  return tuple(grids)
# Typing overloads for `indices`: sparse=False (the default) returns a single
# stacked Array; sparse=True returns a tuple with one Array per dimension.
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            sparse: Literal[False] = False) -> Array: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            *, sparse: Literal[True]) -> Tuple[Array, ...]: ...
# `sparse` not known statically.
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: ...
@_wraps(np.indices)
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            sparse: bool = False) -> Union[Array, Tuple[Array, ...]]:
  dims = tuple(
      core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
      for d in dimensions)
  ndims = len(dims)
  grids = []
  for axis, size in enumerate(dims):
    # Dense grids span the full shape; sparse grids are 1 everywhere but `axis`.
    target = (1,) * axis + (size,) + (1,) * (ndims - axis - 1) if sparse else dims
    grids.append(lax.broadcast_in_dim(lax.iota(dtype, size), target, (axis,)))
  if sparse:
    return tuple(grids)
  return stack(grids, 0) if grids else array([], dtype=dtype)
_TOTAL_REPEAT_LENGTH_DOC = """\
JAX adds the optional `total_repeat_length` parameter which specifies the total
number of repeat, and defaults to sum(repeats). It must be specified for repeat
to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length` the remaining values will be discarded. In the case of
`sum(repeats)` being smaller than the specified target length, the final value
will be repeated.
"""
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
           total_repeat_length: Optional[int] = None) -> Array:
  _check_arraylike("repeat", a)
  # A non-constant (special) dimension size is accepted without the arraylike check.
  core.is_special_dim_size(repeats) or _check_arraylike("repeat", repeats)

  # axis=None repeats the flattened array along axis 0.
  if axis is None:
    a = ravel(a)
    axis = 0
  else:
    a = asarray(a)

  axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()")
  assert isinstance(axis, int)  # to appease mypy

  if core.is_special_dim_size(repeats):
    if total_repeat_length is not None:
      raise ValueError("jnp.repeat with a non-constant `repeats` is supported only "
                       "when `total_repeat_length` is None")

  # If total_repeat_length is not given, use a default.
  if total_repeat_length is None:
    # The output length is sum(repeats), so `repeats` must be concrete here.
    repeats = core.concrete_or_error(None, repeats,
      "When jit-compiling jnp.repeat, the total number of repeats must be static. "
      "To fix this, either specify a static value for `repeats`, or pass a static "
      "value to `total_repeat_length`.")

    # Fast path for when repeats is a scalar.
    # Insert an auxiliary axis just after `axis`, tile along it, then merge
    # the two axes so each element appears `repeats` times consecutively.
    if np.ndim(repeats) == 0 and ndim(a) != 0:
      input_shape = shape(a)
      aux_axis = axis if axis < 0 else axis + 1
      a = expand_dims(a, aux_axis)
      reps: List[DimSize] = [1] * len(shape(a))
      reps[aux_axis] = repeats
      a = tile(a, reps)
      result_shape: List[DimSize] = list(input_shape)
      result_shape[axis] *= repeats
      return reshape(a, result_shape)

    repeats = np.ravel(repeats)
    if ndim(a) != 0:
      repeats = np.broadcast_to(repeats, [shape(a)[axis]])
    total_repeat_length = np.sum(repeats)
  else:
    repeats = ravel(repeats)
    if ndim(a) != 0:
      repeats = broadcast_to(repeats, [shape(a)[axis]])

  # Special case when a is a scalar.
  if ndim(a) == 0:
    if shape(repeats) == (1,):
      return full([total_repeat_length], a)
    else:
      raise ValueError('`repeat` with a scalar parameter `a` is only '
      'implemented for scalar values of the parameter `repeats`.')

  # Special case if total_repeat_length is zero.
  if total_repeat_length == 0:
    result_shape = list(shape(a))
    result_shape[axis] = 0
    return reshape(array([], dtype=_dtype(a)), result_shape)

  # If the repeat axis has size zero, return the array unchanged.
  if shape(a)[axis] == 0:
    return asarray(a)

  # This implementation of repeat avoids having to instantiate a large
  # intermediate tensor: it builds a gather index of length
  # total_repeat_length and takes from `a` along `axis`.

  # Shift repeats right with a leading zero (exclusive prefix), e.g.
  # [1,2,0,5] -> [0,1,2,0].
  exclusive_repeats = roll(repeats, shift=1).at[0].set(0)
  # Cumsum to get the start index of each input element's run in the output, e.g. [0, 1, 3, 3]
  scatter_indices = cumsum(exclusive_repeats)
  # Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
  block_split_indicators = zeros([total_repeat_length], dtype=int32)
  block_split_indicators = block_split_indicators.at[scatter_indices].add(1)
  # Cumsum again to get the gather indices for repeat, e.g. [0,1,1,3,3,3,3,3]
  gather_indices = cumsum(block_split_indicators) - 1
  return take(a, gather_indices, axis=axis)
@_wraps(np.tri)
def tri(N: int, M: Optional[int] = None, k: int = 0,
        dtype: Optional[DTypeLike] = None) -> Array:
  # Lower-triangular mask of shape (N, M) with diagonal offset k, built by
  # lax_internal._tri. M defaults to N and dtype defaults to float32.
  lax_internal._check_user_dtype_supported(dtype, "tri")
  M = M if M is not None else N
  # Compare against None instead of relying on truthiness. A NumPy dtype
  # object can be falsy (len() is its number of fields, 0 for plain dtypes),
  # so `dtype or float32` would silently replace e.g. np.dtype('int32').
  dtype = float32 if dtype is None else dtype
  return lax_internal._tri(dtype, (N, M), k)
@_wraps(np.tril)
@partial(jit, static_argnames=('k',))
def tril(m: ArrayLike, k: int = 0) -> Array:
  _check_arraylike("tril", m)
  m_shape = shape(m)
  if len(m_shape) < 2:
    raise ValueError("Argument to jax.numpy.tril must be at least 2D")
  rows, cols = m_shape[-2:]
  # Keep entries selected by tri(k) (broadcast over leading batch dims) and
  # zero out the rest.
  keep = lax.broadcast(tri(rows, cols, k=k, dtype=bool), m_shape[:-2])
  return lax.select(keep, m, zeros_like(m))
@_wraps(np.triu, update_doc=False)
@partial(jit, static_argnames=('k',))
def triu(m: ArrayLike, k: int = 0) -> Array:
  _check_arraylike("triu", m)
  m_shape = shape(m)
  if len(m_shape) < 2:
    raise ValueError("Argument to jax.numpy.triu must be at least 2D")
  rows, cols = m_shape[-2:]
  # tri(k - 1) marks the entries to zero; everything else is kept from m.
  drop = lax.broadcast(tri(rows, cols, k=k - 1, dtype=bool), m_shape[:-2])
  return lax.select(drop, zeros_like(m), m)
@_wraps(np.trace, skip_params=['out'])
@partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype'))
def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
          dtype: Optional[DTypeLike] = None, out: None = None) -> Array:
  _check_arraylike("trace", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
  lax_internal._check_user_dtype_supported(dtype, "trace")

  a_shape = shape(a)
  if dtype is None:
    dtype = _dtype(a)
    # Integer inputs narrower than the default integer are summed in the
    # default integer type.
    if issubdtype(dtype, integer):
      default_int = dtypes.canonicalize_dtype(np.int_)
      dtype = default_int if iinfo(dtype).bits < iinfo(default_int).bits else dtype

  # Move the two trace axes to the end, zero everything off the requested
  # diagonal, then sum over those last two axes.
  moved = moveaxis(a, (axis1, axis2), (-2, -1))
  on_diag = eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool)
  return sum(where(on_diag, moved, zeros_like(moved)), axis=(-2, -1), dtype=dtype)
def _wrap_indices_function(f):
  """Wrap an index-generating function ``f`` (e.g. ``np.mask_indices``).

  Every positional and keyword argument is first made concrete with
  ``core.concrete_or_error`` (so traced values raise an informative error);
  ``f`` is then called on the concrete values and each returned object is
  converted with ``asarray``. The wrapper always returns a tuple.
  """
  @_wraps(f, update_doc=False)
  def wrapper(*args, **kwargs):
    fname = f.__name__
    concrete_args = []
    for position, value in enumerate(args):
      concrete_args.append(core.concrete_or_error(
          None, value, f"argument {position} of jnp.{fname}()"))
    concrete_kwargs = {}
    for name, value in kwargs.items():
      concrete_kwargs[name] = core.concrete_or_error(
          None, value, f"argument '{name}' of jnp.{fname}()")
    return tuple(map(asarray, f(*concrete_args, **concrete_kwargs)))
  return wrapper
# jnp.mask_indices: evaluates np.mask_indices on concrete (non-traced)
# arguments and converts each returned index array with asarray.
mask_indices = _wrap_indices_function(np.mask_indices)
def _triu_size(n, m, k):
if k < 0:
return n * m - _triu_size(m, n, (1 - k))
elif k >= m:
return 0
else:
mk = _min(n, m - k)
return mk * (mk + 1) // 2 + mk * (m - k - mk)
@_wraps(np.triu_indices)
def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]:
  # Row/column indices of the upper triangle (at and above diagonal k) of an
  # (n, m) array. All arguments must be concrete integers so the output size
  # can be computed statically for `nonzero`.
  n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices")
  k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices")
  if m is None:
    m = n
  else:
    m = core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices")
  upper = triu(ones((n, m)), k=k)
  rows, cols = nonzero(upper, size=_triu_size(n, m, k))
  return rows, cols
@_wraps(np.tril_indices)
def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]:
  # Row/column indices of the lower triangle (at and below diagonal k) of an
  # (n, m) array. All arguments must be concrete integers; the error messages
  # name this function so that failures point at the right API.
  n = core.concrete_or_error(operator.index, n, "n argument of jnp.tril_indices")
  k = core.concrete_or_error(operator.index, k, "k argument of jnp.tril_indices")
  m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.tril_indices")
  # The lower triangle of an (n, m) array at k has as many entries as the
  # upper triangle of the transposed (m, n) array at -k.
  i, j = nonzero(tril(ones((n, m)), k=k), size=_triu_size(m, n, -k))
  return i, j
@_wraps(np.triu_indices_from)
def triu_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]:
  # Upper-triangle indices sized by the last two dimensions of `arr`.
  dims = shape(arr)
  n_rows = dims[-2]
  n_cols = dims[-1]
  return triu_indices(n_rows, k=k, m=n_cols)
@_wraps(np.tril_indices_from)
def tril_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]:
  # Lower-triangle indices sized by the last two dimensions of `arr`.
  dims = shape(arr)
  n_rows = dims[-2]
  n_cols = dims[-1]
  return tril_indices(n_rows, k=k, m=n_cols)
@_wraps(np.diag_indices)
def diag_indices(n, ndim=2):
  # Index arrays selecting the main diagonal of an `ndim`-dimensional array
  # whose axes all have length `n`: one iota(n), repeated once per axis.
  n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
  ndim = core.concrete_or_error(operator.index, ndim, "'ndim' argument of jnp.diag_indices()")
  for arg_name, arg_value in (("n", n), ("ndim", ndim)):
    if arg_value < 0:
      raise ValueError(
          f"{arg_name} argument to diag_indices must be nonnegative, got {arg_value}")
  return (lax.iota(int_, n),) * ndim
@_wraps(np.diag_indices_from)
def diag_indices_from(arr):
  # Main-diagonal indices for an array whose dimensions all have equal length.
  # Use the module's ndim()/shape() helpers instead of attribute access so
  # inputs without .ndim/.shape that pass _check_arraylike (e.g. Python
  # scalars, if admitted) hit the ValueError below rather than an
  # AttributeError.
  _check_arraylike("diag_indices_from", arr)
  arr_ndim = ndim(arr)
  arr_shape = shape(arr)
  if not arr_ndim >= 2:
    raise ValueError("input array must be at least 2-d")

  if len(set(arr_shape)) != 1:
    raise ValueError("All dimensions of input must be of equal length")

  return diag_indices(arr_shape[0], ndim=arr_ndim)
@_wraps(np.diagonal, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('offset', 'axis1', 'axis2'))
def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
  # Extract the `offset` diagonal of the 2-D slices spanned by (axis1, axis2);
  # those two axes are replaced by a single trailing diagonal axis.
  _check_arraylike("diagonal", a)
  a_shape = shape(a)
  if ndim(a) < 2:
    raise ValueError("diagonal requires an array of at least two dimensions.")
  offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")

  a = moveaxis(a, (axis1, axis2), (-2, -1))
  n_rows, n_cols = a_shape[axis1], a_shape[axis2]
  if offset >= 0:
    # Above (or on) the main diagonal: columns are shifted right.
    length = _max(0, _min(n_rows, n_cols - offset))
    rows = arange(length)
    cols = arange(offset, offset + length)
  else:
    # Below the main diagonal: rows are shifted down.
    length = _max(0, _min(n_rows + offset, n_cols))
    rows = arange(-offset, -offset + length)
    cols = arange(length)
  return a[..., rows, cols]
@_wraps(np.diag, lax_description=_ARRAY_VIEW_DOC)
def diag(v, k=0):
  # Normalize the diagonal offset to a Python int; it is a static argument of
  # the jitted implementation.
  diagonal_offset = operator.index(k)
  return _diag(v, diagonal_offset)
@partial(jit, static_argnames=('k',))
def _diag(v, k):
  """Jitted implementation of jnp.diag for a static integer offset ``k``.

  A 1-D ``v`` is placed on the k-th diagonal of a square matrix of side
  ``len(v) + abs(k)``; for a 2-D ``v`` the k-th diagonal is extracted.
  Any other rank raises ValueError.
  """
  _check_arraylike("diag", v)
  v_shape = shape(v)
  if len(v_shape) == 2:
    return diagonal(v, offset=k)
  if len(v_shape) != 1:
    raise ValueError("diag input must be 1d or 2d")
  side = v_shape[0] + _abs(k)
  # Pad v to length `side` so that, broadcast across rows, element
  # v[j - max(k, 0)] sits in column j of the k-th diagonal.
  pad_value = lax.full_like(v, shape=(), fill_value=0)
  v = lax.pad(v, pad_value, ((_max(0, k), _max(0, -k), 0),))
  return where(eye(side, k=k, dtype=bool), v, zeros_like(v))
# Extra note for jnp.diagflat's docstring (passed as its lax_description):
# the JAX version always produces a 2-D result.
_SCALAR_VALUE_DOC = """\
This differs from np.diagflat for some scalar values of v,
jax always returns a two-dimensional array, whereas numpy may
return a scalar depending on the type of v.
"""
@_wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC)
def diagflat(v, k=0):
  # Flatten `v` and place it on the k-th diagonal of a square 2-D array of
  # side len(v) + |k|, by scattering into a flat zero buffer and reshaping.
  _check_arraylike("diagflat", v)
  v = ravel(v)
  side = len(v) + _abs(k)
  steps = arange(0, side - _abs(k))
  if k >= 0:
    rows, cols = steps, steps + k
  else:
    rows, cols = steps - k, steps
  flat_positions = rows * side + cols
  flat = zeros(side * side, dtype=v.dtype).at[flat_positions].set(v)
  return flat.reshape(side, side)
@_wraps(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
  # The output length depends on the data, so `filt` must be concrete.
  filt = core.concrete_or_error(asarray, filt,
                                "Error arose in the `filt` argument of trim_zeros()")
  is_zero = filt == 0
  if all(is_zero):
    return empty(0, _dtype(filt))
  mode = trim.lower()
  # argmin over a boolean mask gives the position of the first False, i.e.
  # the first nonzero entry from the respective end.
  n_leading = argmin(is_zero) if 'f' in mode else 0
  n_trailing = argmin(is_zero[::-1]) if 'b' in mode else 0
  return filt[n_leading:len(filt) - n_trailing]
def trim_zeros_tol(filt, tol, trim='fb'):
  """Like ``trim_zeros``, but treats entries with ``abs(x) < tol`` as zero.

  ``filt`` must be concrete. ``trim`` containing 'f' trims the front and 'b'
  the back (case-insensitive). If every entry is below ``tol``, an empty
  array of ``filt``'s dtype is returned.
  """
  filt = core.concrete_or_error(asarray, filt,
                                "Error arose in the `filt` argument of trim_zeros_tol()")
  is_small = abs(filt) < tol
  if all(is_small):
    return empty(0, _dtype(filt))
  mode = trim.lower()
  n_leading = argmin(is_small) if 'f' in mode else 0
  n_trailing = argmin(is_small[::-1]) if 'b' in mode else 0
  return filt[n_leading:len(filt) - n_trailing]
@_wraps(np.append)
@partial(jit, static_argnames=('axis',))
def append(arr, values, axis: Optional[int] = None):
  # With an explicit axis, concatenate along it; with axis=None, both inputs
  # are flattened first.
  if axis is not None:
    return concatenate([arr, values], axis=axis)
  return concatenate([ravel(arr), ravel(values)], 0)
@_wraps(np.delete)
def delete(arr, obj, axis=None):
  # Remove the entries selected by `obj` along `axis` (along the flattened
  # array when axis is None). `obj` may be a static integer, a static slice,
  # or an integer/boolean array that must be concrete (checked below by
  # core.concrete_or_error), since the output shape depends on it.
  _check_arraylike("delete", arr)
  if axis is None:
    arr = ravel(arr)
    axis = 0
  axis = _canonicalize_axis(axis, arr.ndim)

  # Case 1: obj is a static integer.
  try:
    obj = operator.index(obj)
    # Normalize a negative index against the length of `axis`.
    obj = _canonicalize_axis(obj, arr.shape[axis])
  except TypeError:
    pass
  else:
    # Concatenate the pieces before and after the deleted position.
    idx = tuple(slice(None) for i in range(axis))
    return concatenate([arr[idx + (slice(0, obj),)], arr[idx + (slice(obj + 1, None),)]], axis=axis)

  # Case 2: obj is a static slice.
  if isinstance(obj, slice):
    # TODO(jakevdp): we should be able to do this dynamically with care.
    # Compute the surviving positions on the host with NumPy, then gather.
    indices = np.delete(np.arange(arr.shape[axis]), obj)
    return take(arr, indices, axis=axis)

  # Case 3: obj is an array
  # NB: pass both arrays to check for appropriate error message.
  _check_arraylike("delete", arr, obj)
  obj = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()")

  if issubdtype(obj.dtype, integer):
    # TODO(jakevdp): in theory this could be done dynamically if obj has no duplicates,
    # but this would require the complement of lax.gather.
    # Host-side keep-mask with False at every index to delete.
    mask = np.ones(arr.shape[axis], dtype=bool)
    mask[obj] = False
  elif obj.dtype == bool:
    if obj.shape != (arr.shape[axis],):
      raise ValueError("np.delete(arr, obj): for boolean indices, obj must be one-dimensional "
                       "with length matching specified axis.")
    # A boolean obj marks the entries to delete; keep the complement.
    mask = ~obj
  else:
    raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.")

  # Boolean-mask indexing along `axis` keeps the surviving entries.
  return arr[tuple(slice(None) for i in range(axis)) + (mask,)]
@_wraps(np.insert)
def insert(arr, obj, values, axis=None):
  # Insert `values` before the positions given by `obj` along `axis` (along
  # the flattened array when axis is None). The result is built by scattering
  # both the inserted values and the original entries into a zero buffer of
  # the enlarged shape.
  _check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
  arr = asarray(arr)
  values = asarray(values)

  if axis is None:
    arr = ravel(arr)
    axis = 0
  axis = core.concrete_or_error(None, axis, "axis argument of jnp.insert()")
  axis = _canonicalize_axis(axis, arr.ndim)
  # Normalize obj into an index array.
  if isinstance(obj, slice):
    indices = arange(*obj.indices(arr.shape[axis]))
  else:
    indices = asarray(obj)

  if indices.ndim > 1:
    raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional "
                     f"array, or a scalar; got {obj}")
  if not np.issubdtype(indices.dtype, np.integer):
    # An empty, non-Array obj (whose asarray dtype may not be integer, e.g.
    # an empty list) is accepted and cast to int.
    if indices.size == 0 and not isinstance(obj, Array):
      indices = indices.astype(int)
    else:
      # Note: np.insert allows boolean inputs but the behavior is deprecated.
      raise ValueError("jnp.insert(): index array must be "
                       f"integer typed; got {obj}")
  values = array(values, ndmin=arr.ndim, dtype=arr.dtype, copy=False)

  if indices.size == 1:
    # A single insertion point: insert every slice of `values` there. For a
    # 0-d obj the leading axis of `values` is first moved to `axis`; the index
    # is then repeated once per inserted slice.
    index = ravel(indices)[0]
    if indices.ndim == 0:
      values = moveaxis(values, 0, axis)
    indices = full(values.shape[axis], index)
  n_input = arr.shape[axis]
  n_insert = broadcast_shapes(indices.shape, values.shape[axis])[0]
  out_shape = list(arr.shape)
  out_shape[axis] += n_insert
  out = zeros_like(arr, shape=tuple(out_shape))

  # Wrap negative indices once, then clip into [0, n_input].
  indices = where(indices < 0, indices + n_input, indices)
  indices = clip(indices, 0, n_input)

  # Output position of each inserted value: its index plus its rank among
  # the sorted insertion indices.
  values_ind = indices.at[argsort(indices)].add(arange(n_insert, dtype=indices.dtype))
  # The original entries fill the remaining n_input positions, in order.
  arr_mask = ones(n_input + n_insert, dtype=bool).at[values_ind].set(False)
  arr_ind = where(arr_mask, size=n_input)[0]

  out = out.at[(slice(None),) * axis + (values_ind,)].set(values)
  out = out.at[(slice(None),) * axis + (arr_ind,)].set(arr)

  return out
@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
  # Apply func1d to every 1-D slice along `axis` by vmapping it over all
  # other axes: first the trailing axes (their outputs go to the end), then
  # the leading axes (their outputs stay in front).
  num_dims = ndim(arr)
  axis = _canonicalize_axis(axis, num_dims)

  def mapped(slice_1d):
    return func1d(slice_1d, *args, **kwargs)

  for in_axis in range(1, num_dims - axis):
    mapped = jax.vmap(mapped, in_axes=in_axis, out_axes=-1)
  for _ in range(axis):
    mapped = jax.vmap(mapped, in_axes=0, out_axes=0)
  return mapped(arr)
@_wraps(np.apply_over_axes)
def apply_over_axes(func, a, axes):
  # Call func(a, axis=axis) for each axis in turn. A result with one fewer
  # dimension has the reduced axis restored via expand_dims, so the rank of
  # `a` never changes; any other rank raises ValueError.
  # Like np.apply_over_axes, a single integer axis is accepted in addition to
  # a sequence of axes.
  try:
    axes = (operator.index(axes),)
  except TypeError:
    pass  # Not a single integer: iterate over `axes` as given.
  for axis in axes:
    b = func(a, axis=axis)
    if b.ndim == a.ndim:
      a = b
    elif b.ndim == a.ndim - 1:
      a = expand_dims(b, axis)
    else:
      raise ValueError("function is not returning an array of the correct shape")
  return a
### Tensor contraction operations
@_wraps(np.dot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def dot(a, b, *, precision=None):  # pylint: disable=missing-docstring
  _check_arraylike("dot", a, b)
  a, b = _promote_dtypes(a, b)
  a_ndim, b_ndim = ndim(a), ndim(b)
  # A scalar operand reduces np.dot to elementwise multiplication.
  if a_ndim == 0 or b_ndim == 0:
    return lax.mul(a, b)
  # Vector/matrix operands map directly onto lax.dot.
  if _max(a_ndim, b_ndim) <= 2:
    return lax.dot(a, b, precision=precision)
  # N-d case: contract the last axis of `a` with the only axis of a 1-D `b`,
  # or with the second-to-last axis of a higher-rank `b`; no batch axes.
  b_contract = 0 if b_ndim == 1 else b_ndim - 2
  dimension_numbers = (((a_ndim - 1,), (b_contract,)), ((), ()))
  return lax.dot_general(a, b, dimension_numbers, precision)
@_wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def matmul(a, b, *, precision=None):  # pylint: disable=missing-docstring
  # Implemented as a single lax.dot_general: the last axis of `a` is
  # contracted with the second-to-last axis of `b` (its only axis if `b` is
  # 1-D). Leading axes are broadcast as batch dimensions, squeezing size-1
  # dims where needed, and the output is finally transposed into matmul's
  # axis order.
  _check_arraylike("matmul", a, b)
  for i, x in enumerate((a, b)):
    if ndim(x) < 1:
      msg = (f"matmul input operand {i} must have ndim at least 1, "
             f"but it has ndim {ndim(x)}")
      raise ValueError(msg)

  a, b = _promote_dtypes(a, b)

  # A 1-D operand is a vector and contributes no batch dimensions.
  a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
  a_batch_dims = shape(a)[:-2] if a_is_mat else ()
  b_batch_dims = shape(b)[:-2] if b_is_mat else ()
  num_batch_dims = _max(len(a_batch_dims), len(b_batch_dims))
  # Left-pad the shorter batch shape with None to mark absent dimensions.
  a_batch_dims = (None,) * (num_batch_dims - len(a_batch_dims)) + a_batch_dims
  b_batch_dims = (None,) * (num_batch_dims - len(b_batch_dims)) + b_batch_dims

  # Dimensions to squeeze from the inputs.
  a_squeeze = []
  b_squeeze = []

  # Positions of batch dimensions in squeezed inputs.
  a_batch = []
  b_batch = []

  # Desired index in final output of each kind of dimension, in the order that
  # lax.dot_general will emit them.
  idx_batch = []
  idx_a_other = []  # other = non-batch, non-contracting.
  idx_b_other = []
  for i, (ba, bb) in enumerate(zip(a_batch_dims, b_batch_dims)):
    if ba is None:
      # Only b has this dimension: it passes through as a free dim of b.
      idx_b_other.append(i)
    elif bb is None:
      # Only a has this dimension: it passes through as a free dim of a.
      idx_a_other.append(i)
    elif core.symbolic_equal_dim(ba, 1):
      # a's size-1 dim broadcasts: squeeze it from a; b's dim is free.
      idx_b_other.append(i)
      a_squeeze.append(len(idx_batch) + len(idx_a_other) + len(a_squeeze))
    elif core.symbolic_equal_dim(bb, 1):
      # b's size-1 dim broadcasts: squeeze it from b; a's dim is free.
      idx_a_other.append(i)
      b_squeeze.append(len(idx_batch) + len(idx_b_other) + len(b_squeeze))
    elif core.symbolic_equal_dim(ba, bb):
      # Equal sizes: a true dot_general batch dimension.
      a_batch.append(len(idx_batch) + len(idx_a_other))
      b_batch.append(len(idx_batch) + len(idx_b_other))
      idx_batch.append(i)
    else:
      raise ValueError("Incompatible shapes for matmul arguments: {} and {}"
                       .format(shape(a), shape(b)))

  # The row axis of a matrix `a` and the column axis of a matrix `b` are
  # free (non-contracting) output dimensions.
  if a_is_mat: idx_a_other.append(num_batch_dims)
  if b_is_mat: idx_b_other.append(num_batch_dims + a_is_mat)
  # Permutation from dot_general's (batch, a-free, b-free) layout to the
  # desired output order.
  perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other]))

  a = lax.squeeze(a, tuple(a_squeeze))
  b = lax.squeeze(b, tuple(b_squeeze))
  out = lax.dot_general(
    a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
    precision=precision)
  return lax.transpose(out, perm)
@_wraps(np.vdot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def vdot(a, b, *, precision=None):
  # Dot product of the flattened inputs, conjugating `a` first when it has a
  # complex dtype.
  _check_arraylike("vdot", a, b)
  lhs = conj(a) if issubdtype(_dtype(a), complexfloating) else a
  return dot(lhs.ravel(), b.ravel(), precision=precision)
@_wraps(np.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, *, precision=None):
  # Contract `a` and `b` over the axes selected by `axes`:
  #   * an integer N contracts the last N axes of `a` with the first N of `b`;
  #   * a pair of integers, or a pair of equal-length lists/tuples, contracts
  #     the listed axes of `a` against the listed axes of `b`.
  # NumPy integer scalars are accepted wherever Python ints are, as in
  # np.tensordot. Malformed `axes` raise TypeError.
  _check_arraylike("tensordot", a, b)
  a_ndim = ndim(a)
  b_ndim = ndim(b)

  def _is_int(x):
    # Python or NumPy integer; bool is deliberately not treated as an axis.
    return isinstance(x, (int, np.integer)) and not isinstance(x, bool)

  a, b = _promote_dtypes(a, b)
  if _is_int(axes):
    axes = operator.index(axes)
    if axes > _min(a_ndim, b_ndim):
      msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
      raise TypeError(msg.format(axes, a.shape, b.shape))
    contracting_dims = tuple(range(a_ndim - axes, a_ndim)), tuple(range(axes))
  elif type(axes) in (list, tuple) and len(axes) == 2:
    ax1, ax2 = axes
    if _is_int(ax1) and _is_int(ax2):
      contracting_dims = ((_canonicalize_axis(operator.index(ax1), a_ndim),),
                          (_canonicalize_axis(operator.index(ax2), b_ndim),))
    elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
      if len(ax1) != len(ax2):
        msg = "tensordot requires axes lists to have equal length, got {} and {}."
        raise TypeError(msg.format(ax1, ax2))
      contracting_dims = (tuple(_canonicalize_axis(i, a_ndim) for i in ax1),
                          tuple(_canonicalize_axis(i, b_ndim) for i in ax2))
    else:
      msg = ("tensordot requires both axes lists to be either ints, tuples or "
             "lists, got {} and {}")
      raise TypeError(msg.format(ax1, ax2))
  else:
    msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
           "of lists/tuples of ints.")
    raise TypeError(msg)
  return lax.dot_general(a, b, (contracting_dims, ((), ())),
                         precision=precision)
# Extra documentation for jnp.einsum: the generic precision note plus how a
# tuple `precision` is applied to the underlying dot_general calls.
_EINSUM_DOC = _PRECISION_DOC + """\
A tuple ``precision`` does not necessarily map to multiple arguments of ``einsum()``;
rather, the specified ``precision`` is forwarded to each ``dot_general`` call used in
the implementation.
"""
@_wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out'])
def einsum(*operands, out=None, optimize='optimal', precision=None,
           _use_xeinsum=False):
  # Entry point for jnp.einsum. Specs containing '{' (or _use_xeinsum=True)
  # are routed to lax.xeinsum. Otherwise a contraction path is planned with
  # opt_einsum (or a shape-polymorphism handler) and executed by the jitted
  # _einsum.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")

  spec = operands[0] if isinstance(operands[0], str) else None

  if (_use_xeinsum or spec is not None and '{' in spec):
    return jax.named_call(lax.xeinsum, name=spec)(*operands)

  # optimize=True selects the 'optimal' path strategy.
  optimize = 'optimal' if optimize is True else optimize
  # using einsum_call=True here is an internal api for opt_einsum

  # Allow handling of shape polymorphism
  # Collect the types of any non-constant (symbolic) dimensions of the array
  # operands; string operands (subscripts) are skipped.
  non_constant_dim_types = {
      type(d) for op in operands if not isinstance(op, str)
      for d in np.shape(op) if not core.is_constant_dim(d)
  }
  if not non_constant_dim_types:
    contract_path = opt_einsum.contract_path
  else:
    # Dispatch on one of the symbolic-dimension types; use the default
    # handler when none is registered for it.
    ty = next(iter(non_constant_dim_types))
    contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
  operands, contractions = contract_path(
        *operands, einsum_call=True, use_blas=True, optimize=optimize)

  # Keep (operand indices, contracted names as a frozenset, einsum string)
  # from each step; this hashable tuple is a static argument of _einsum.
  contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)

  _einsum_computation = jax.named_call(
      _einsum, name=spec) if spec is not None else _einsum
  return _einsum_computation(operands, contractions, precision)
# Enable other modules to override einsum_contract_path.
# Indexed by the type of the non constant dimension
# einsum looks up the type of a symbolic dimension here and, if found, uses
# the registered callable in place of opt_einsum.contract_path.
_poly_einsum_handlers = {}  # type: ignore
def _default_poly_einsum_handler(*operands, **kwargs):
  """Plan an einsum contraction path when some dimensions are symbolic.

  Each operand with a ``dtype`` is replaced by a stand-in whose shape has
  every non-int dimension replaced by 8, so opt_einsum can plan with concrete
  sizes. Other operands (e.g. the subscript string) are passed through
  unchanged. The operands in the returned path are mapped back (by object
  identity) to the original ``operands``.

  Returns:
    ``(contract_operands, contractions)`` in the form produced by
    ``opt_einsum.contract_path`` with the given ``kwargs``.
  """
  dummy = collections.namedtuple('dummy', ['shape', 'dtype'])
  dummies = [dummy(tuple(d if type(d) is int else 8 for d in x.shape), x.dtype)
             if hasattr(x, 'dtype') else x for x in operands]
  # id() of each stand-in -> its position in `operands`.
  mapping = {id(d): i for i, d in enumerate(dummies)}
  out_dummies, contractions = opt_einsum.contract_path(*dummies, **kwargs)
  contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
  return contract_operands, contractions
@_wraps(np.einsum_path)
def einsum_path(subscripts, *operands, optimize='greedy'):
  # Delegate to opt_einsum.contract_path with the requested optimize strategy
  # (default 'greedy'); einsum_call is not passed here.
  return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)
def _removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))
@partial(jit, static_argnums=(1, 2))
def _einsum(operands: Sequence,
            contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
            precision):
  """Execute a precomputed einsum contraction path with lax primitives.

  Args:
    operands: the einsum array operands; their dtypes are promoted together.
    contractions: static sequence of ``(operand_indices, contracted_names,
      einstr)`` steps. Each step pops the listed operands, contracts them as
      described by ``einstr`` (``"in1[,in2]->out"``) and appends the result.
    precision: static precision forwarded to every ``lax.dot_general`` call.

  Returns:
    The single operand left after all contraction steps.
  """
  operands = list(_promote_dtypes(*operands))
  def sum(x, axes):
    # Reduce-sum over `axes`; boolean operands are OR-reduced instead.
    return lax.reduce(x, np.array(0, x.dtype),
                      lax.add if x.dtype != bool_ else lax.bitwise_or, axes)

  def sum_uniques(operand, names, uniques):
    # Sum out the axes labelled by `uniques` and drop those labels.
    if uniques:
      axes = [names.index(name) for name in uniques]
      operand = sum(operand, axes)
      names = _removechars(names, uniques)
    return operand, names

  def sum_repeats(operand, names, counts, keep_names):
    # For each label repeated within `names`, multiply by a delta mask over
    # those axes and sum: over all of them if the label is not in
    # `keep_names`, otherwise over all but the last occurrence.
    for name, count in counts.items():
      if count > 1:
        axes = [i for i, n in enumerate(names) if n == name]
        eye = lax_internal._delta(operand.dtype, operand.shape, axes)
        if name not in keep_names:
          operand = sum(operand * eye, axes)
          names = names.replace(name, '')
        else:
          operand = sum(operand * eye, axes[:-1])
          names = names.replace(name, '', count - 1)
    return operand, names

  def filter_singleton_dims(operand, names, other_shape, other_names):
    # Squeeze size-1 axes of `operand` whose label also appears in the other
    # operand with a size other than 1; all other axes are kept.
    eq = core.symbolic_equal_dim
    keep = [not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
            for i, j in enumerate(map(other_names.find, names))]
    sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))
    return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes)

  for operand_indices, contracted_names_set, einstr in contractions:
    contracted_names = sorted(contracted_names_set)
    input_str, result_names = einstr.split('->')
    input_names = input_str.split(',')

    # switch on the number of operands to be processed in this loop iteration.
    # every case here sets 'operand' and 'names'.
    if len(operand_indices) == 1:
      operand = operands.pop(operand_indices[0])
      names, = input_names
      counts = collections.Counter(names)

      # sum out unique contracted indices with a single reduce-sum
      uniques = [name for name in contracted_names if counts[name] == 1]
      operand, names = sum_uniques(operand, names, uniques)

      # for every repeated index, do a contraction against an identity matrix
      operand, names = sum_repeats(operand, names, counts, result_names)

    elif len(operand_indices) == 2:
      lhs, rhs = map(operands.pop, operand_indices)
      lhs_names, rhs_names = input_names

      # handle cases where one side of a contracting or batch dimension is 1
      # but its counterpart is not.
      lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
                                             rhs_names)
      rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
                                             lhs_names)

      lhs_counts = collections.Counter(lhs_names)
      rhs_counts = collections.Counter(rhs_names)

      # sum out unique contracted indices in lhs and rhs
      lhs_uniques = [name for name in contracted_names
                     if lhs_counts[name] == 1 and rhs_counts[name] == 0]
      lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)

      rhs_uniques = [name for name in contracted_names
                     if rhs_counts[name] == 1 and lhs_counts[name] == 0]
      rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)

      # for every repeated index, contract against an identity matrix
      lhs, lhs_names = sum_repeats(lhs, lhs_names, lhs_counts,
                                   result_names + rhs_names)
      rhs, rhs_names = sum_repeats(rhs, rhs_names, rhs_counts,
                                   result_names + lhs_names)

      # Contracted labels still present on either side, and batch labels:
      # result labels present on both sides.
      lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
      contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
      lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
      batch_names = [x for x in result_names if x in lhs_and_rhs_names]

      lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
                                    for n in batch_names)

      # NOTE(mattjj): this can fail non-deterministically in python3, maybe
      # due to opt_einsum
      assert jax.config.jax_dynamic_shapes or _all(
        name in lhs_names and name in rhs_names and
        lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
        for name in contracted_names), (
          "Incompatible reduction dimensions: "
          f"lhs.shape={lhs.shape} lhs_names={lhs_names} "
          f"rhs.shape={rhs.shape} rhs_names={rhs_names}")

      # contract using lax.dot_general
      batch_names_str = ''.join(batch_names)
      lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
                                  for n in contracted_names)
      deleted_names = batch_names_str + ''.join(contracted_names)
      remaining_lhs_names = _removechars(lhs_names, deleted_names)
      remaining_rhs_names = _removechars(rhs_names, deleted_names)
      # Try both orders of lhs and rhs, in the hope that one of them means we
      # don't need an explicit transpose. opt_einsum likes to contract from
      # right to left, so we expect (rhs,lhs) to have the best chance of not
      # needing a transpose.
      names = batch_names_str + remaining_rhs_names + remaining_lhs_names
      if names == result_names:
        dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
        operand = lax.dot_general(rhs, lhs, dimension_numbers, precision)
      else:
        names = batch_names_str + remaining_lhs_names + remaining_rhs_names
        dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
        operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
    else:
      raise NotImplementedError  # if this is actually reachable, open an issue!

    # the resulting 'operand' with axis labels 'names' should be a permutation
    # of the desired result
    assert len(names) == len(result_names) == len(set(names))
    assert set(names) == set(result_names)
    if names != result_names:
      perm = tuple(names.index(name) for name in result_names)
      operand = lax.transpose(operand, perm)
    operands.append(operand)  # used in next iteration

  return operands[0]
@_wraps(np.inner, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def inner(a, b, *, precision=None):
  # Contract the last axis of `a` with the last axis of `b`; if either
  # operand is a scalar this is plain elementwise multiplication.
  has_scalar = 0 in (ndim(a), ndim(b))
  if has_scalar:
    return a * b
  return tensordot(a, b, (-1, -1), precision=precision)
@_wraps(np.outer, skip_params=['out'])
@partial(jit, inline=True)
def outer(a, b, out=None):
  # Outer product of the flattened inputs via broadcasting a column of `a`
  # against a row of `b`.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
  a, b = _promote_dtypes(a, b)
  column = ravel(a)[:, None]
  row = ravel(b)[None, :]
  return column * row
@_wraps(np.cross)
@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis'))
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
          axis: Optional[int] = None):
  # `axis`, when given, overrides all three per-operand axis arguments.
  if axis is not None:
    axisa = axisb = axisc = axis
  a = moveaxis(a, axisa, -1)
  b = moveaxis(b, axisb, -1)
  a_len, b_len = a.shape[-1], b.shape[-1]
  if a_len not in (2, 3) or b_len not in (2, 3):
    raise ValueError("Dimension must be either 2 or 3 for cross product")
  if a_len == 2 and b_len == 2:
    # Two 2-vectors: return only the z-component.
    return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

  def components(vec):
    # (x, y, z) components; a 2-vector gets a zero z-component.
    x, y = vec[..., 0], vec[..., 1]
    z = vec[..., 2] if vec.shape[-1] == 3 else zeros_like(x)
    return x, y, z

  a0, a1, a2 = components(a)
  b0, b1, b2 = components(b)
  c = array([a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0])
  return moveaxis(c, 0, axisc)
@_wraps(np.kron)
@jit
def kron(a, b):
  # Give both operands the same rank (leading unit axes), interleave their
  # axes via broadcasting, then merge each (a_i, b_i) axis pair.
  a, b = _promote_dtypes(a, b)
  rank_gap = ndim(a) - ndim(b)
  if rank_gap < 0:
    a = expand_dims(a, range(-rank_gap))
  elif rank_gap > 0:
    b = expand_dims(b, range(rank_gap))
  rank = ndim(a)
  a_spread = expand_dims(a, range(1, 2 * rank, 2))
  b_spread = expand_dims(b, range(0, 2 * rank, 2))
  out_shape = tuple(np.multiply(shape(a), shape(b)))
  return reshape(lax.mul(a_spread, b_spread), out_shape)
@_wraps(np.vander)
@partial(jit, static_argnames=('N', 'increasing'))
def vander(x, N=None, increasing=False):
  # Column j holds x raised to the j-th exponent, with exponents running
  # from N - 1 down to 0 (or 0 up to N - 1 when `increasing`).
  _check_arraylike("vander", x)
  x = asarray(x)
  if x.ndim != 1:
    raise ValueError("x must be a one-dimensional array")
  if N is None:
    N = x.shape[0]
  else:
    N = core.concrete_or_error(operator.index, N, "'N' argument of jnp.vander()")
  if N < 0:
    raise ValueError("N must be nonnegative")

  exponents = lax.iota(x.dtype, N)
  if not increasing:
    exponents = lax.sub(_lax_const(exponents, N - 1), exponents)

  return power(x[..., None], expand_dims(exponents, tuple(range(x.ndim))))
### Misc
# NOTE(review): not referenced in this part of the file (jnp.argwhere below
# supplies its own lax_description); check for other users before editing.
_ARGWHERE_DOC = """\
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
specifies the size of the leading dimension of the output - it must be specified statically
for ``jnp.argwhere`` to be compiled with non-static operands. If ``size`` is specified,
the indices of the first ``size`` True elements will be returned; if there are fewer
nonzero elements than `size` indicates, the index arrays will be zero-padded.
"""
@_wraps(np.argwhere,
        lax_description=_dedent("""
    Because the size of the output of ``argwhere`` is data-dependent, the function is not
    typically compatible with JIT. The JAX version adds the optional ``size`` argument which
    must be specified statically for ``jnp.argwhere`` to be used within some of JAX's
    transformations."""),
        extra_params=_dedent("""
    size : int, optional
        If specified, the indices of the first ``size`` True elements will be returned. If there
        are fewer results than ``size`` indicates, the return value will be padded with ``fill_value``.
    fill_value : array_like, optional
        When ``size`` is specified and there are fewer than the indicated number of elements, the
        remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def argwhere(a, *, size=None, fill_value=None):
  # One row per selected element and one column per dimension of `a`.
  coords = nonzero(a, size=size, fill_value=fill_value)
  result = transpose(vstack(coords))
  n_found = result.shape[0]
  if ndim(a) == 0:
    # A 0-d input has no coordinates: return an (n_found, 0) array.
    return result[:0].reshape(n_found, 0)
  return result.reshape(n_found, ndim(a))
@_wraps(np.argmax, skip_params=['out'])
def argmax(a, axis: Optional[int] = None, out=None, keepdims=None):
  # `out` is unsupported. Reject it here rather than silently ignoring it,
  # because it is not forwarded to the jitted implementation.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.")
  # axis and keepdims are normalized to hashable static values for _argmax.
  return _argmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims))
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmax(a, axis: Optional[int] = None, out=None, keepdims=False):
  """Jitted jnp.argmax; ``axis`` and ``keepdims`` are static.

  With ``axis=None`` the flattened array is searched and, if ``keepdims``,
  every axis is restored as size 1. Raises ValueError on an empty axis.
  """
  _check_arraylike("argmax", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.")
  if axis is not None:
    restore_dims = [axis]
  else:
    restore_dims = list(range(ndim(a)))
    a = ravel(a)
    axis = 0
  if a.shape[axis] == 0:
    raise ValueError("attempt to get argmax of an empty sequence")
  index_dtype = dtypes.canonicalize_dtype(int_)
  result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), index_dtype)
  if keepdims:
    return expand_dims(result, restore_dims)
  return result
@_wraps(np.argmin, skip_params=['out'])
def argmin(a, axis: Optional[int] = None, out=None, keepdims=None):
  # `out` is unsupported. Reject it here rather than silently ignoring it,
  # because it is not forwarded to the jitted implementation.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
  # axis and keepdims are normalized to hashable static values for _argmin.
  return _argmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims))
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmin(a, axis: Optional[int] = None, out=None, keepdims=False):
  """Jitted jnp.argmin; ``axis`` and ``keepdims`` are static.

  With ``axis=None`` the flattened array is searched and, if ``keepdims``,
  every axis is restored as size 1. Raises ValueError on an empty axis.
  """
  _check_arraylike("argmin", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
  if axis is not None:
    restore_dims = [axis]
  else:
    restore_dims = list(range(ndim(a)))
    a = ravel(a)
    axis = 0
  if a.shape[axis] == 0:
    raise ValueError("attempt to get argmin of an empty sequence")
  index_dtype = dtypes.canonicalize_dtype(int_)
  result = lax.argmin(a, _canonicalize_axis(axis, a.ndim), index_dtype)
  if keepdims:
    return expand_dims(result, restore_dims)
  return result
_NANARG_DOC = """\
Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise
an error.
"""
@_wraps(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out'])
def nanargmax(a, axis: Optional[int] = None, out : Any = None, keepdims : Optional[bool] = None):
  # Reject the unsupported `out`, then pass hashable static axis/keepdims
  # values to the jitted implementation.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.")
  static_axis = operator.index(axis) if axis is not None else None
  static_keepdims = bool(keepdims)
  return _nanargmax(a, static_axis, keepdims=static_keepdims)
@partial(jit, static_argnames=('axis', 'keepdims'))
def _nanargmax(a, axis: Optional[int] = None, keepdims: bool = False):
  """Jitted argmax that ignores NaNs; all-NaN slices yield -1."""
  _check_arraylike("nanargmax", a)
  if issubdtype(_dtype(a), inexact):
    is_nan = isnan(a)
    # NaNs are replaced by -inf so that they never beat a real value.
    best = argmax(where(is_nan, -inf, a), axis=axis, keepdims=keepdims)
    return where(all(is_nan, axis=axis, keepdims=keepdims), -1, best)
  # Non-inexact dtypes cannot hold NaN: plain argmax.
  return argmax(a, axis=axis, keepdims=keepdims)
@_wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out'])
def nanargmin(a, axis: Optional[int] = None, out : Any = None, keepdims : Optional[bool] = None):
  # Reject the unsupported `out`, then pass hashable static axis/keepdims
  # values to the jitted implementation.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.")
  static_axis = operator.index(axis) if axis is not None else None
  static_keepdims = bool(keepdims)
  return _nanargmin(a, static_axis, keepdims=static_keepdims)
@partial(jit, static_argnames=('axis', 'keepdims'))
def _nanargmin(a, axis: Optional[int] = None, keepdims : bool = False):
  """Jitted argmin that ignores NaNs; all-NaN slices yield -1."""
  _check_arraylike("nanargmin", a)
  if issubdtype(_dtype(a), inexact):
    is_nan = isnan(a)
    # NaNs are replaced by +inf so that they never beat a real value.
    best = argmin(where(is_nan, inf, a), axis=axis, keepdims=keepdims)
    return where(all(is_nan, axis=axis, keepdims=keepdims), -1, best)
  # Non-inexact dtypes cannot hold NaN: plain argmin.
  return argmin(a, axis=axis, keepdims=keepdims)
@_wraps(np.sort)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
  # lax.sort is always used; a non-default `kind` only triggers a warning,
  # and axis=None sorts the flattened array.
  _check_arraylike("sort", a)
  if kind != 'quicksort':
    warnings.warn("'kind' argument to sort is ignored.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")

  if axis is None:
    operand, dimension = a.ravel(), 0
  else:
    operand, dimension = a, _canonicalize_axis(axis, ndim(a))
  return lax.sort(operand, dimension=dimension)
@_wraps(np.sort_complex)
@jit
def sort_complex(a):
  # Sort along the last axis, as np.sort_complex does (it calls ndarray.sort()
  # with its default axis=-1), then cast to the matching complex dtype.
  # Sorting along axis 0 would disagree with NumPy for inputs with ndim >= 2.
  _check_arraylike("sort_complex", a)
  a = lax.sort(a, dimension=-1)
  return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
@_wraps(np.lexsort)
@partial(jit, static_argnames=('axis',))
def lexsort(keys, axis=-1):
  # The last key is the primary one, so keys go to lax.sort in reverse order,
  # followed by an iota whose permuted values form the result.
  keys = tuple(keys)
  if not keys:
    raise TypeError("need sequence of keys with len > 0 in lexsort")
  if len({shape(key) for key in keys}) > 1:
    raise ValueError("all keys need to be the same shape")
  first = keys[0]
  if ndim(first) == 0:
    return array(0, dtype=dtypes.canonicalize_dtype(int_))
  axis = _canonicalize_axis(axis, ndim(first))
  # Switch to int64 indices once the sorted axis reaches 2**31 elements.
  index_dtype = int64 if first.shape[axis] >= (1 << 31) else int_
  iota = lax.broadcasted_iota(index_dtype, shape(first), axis)
  sorted_operands = lax.sort((*reversed(keys), iota), dimension=axis, num_keys=len(keys))
  return sorted_operands[-1]
# Extra documentation for jnp.argsort: any `kind` other than 'stable' only
# produces a warning.
_ARGSORT_DOC = """
Only :code:`kind='stable'` is supported. Other :code:`kind` values will produce
a warning and be treated as if they were :code:`'stable'`.
"""
@_wraps(np.argsort, lax_description=_ARGSORT_DOC)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
def argsort(a, axis: Optional[int] = -1, kind='stable', order=None):
  # Sort an iota alongside `a`; the permuted iota is the sorting permutation.
  _check_arraylike("argsort", a)
  if kind != 'stable':
    warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts "
                  "are supported.")
  if order is not None:
    raise ValueError("'order' argument to argsort is not supported.")

  if axis is None:
    # axis=None: argsort of the flattened array.
    return argsort(a.ravel(), 0)
  dim = _canonicalize_axis(axis, ndim(a))
  # Switch to int64 indices once the sorted axis reaches 2**31 elements.
  index_dtype = int64 if a.shape[dim] >= (1 << 31) else int_
  iota = lax.broadcasted_iota(index_dtype, shape(a), dim)
  return lax.sort_key_val(a, iota, dimension=dim)[1]
@_wraps(np.msort)
def msort(a):
  # Deprecated: equivalent to sort(a, axis=0), plus a DeprecationWarning.
  # TODO(jakevdp): remove msort after Feb 2023
  message = "jnp.msort is deprecated; use jnp.sort(a, axis=0) instead"
  warnings.warn(message, DeprecationWarning)
  return sort(a, axis=0)
@_wraps(np.partition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
  # Along `axis`, the output holds the kth + 1 smallest values in ascending
  # order (so position kth has its sorted value), followed by the remaining
  # values in descending order.
  # TODO(jakevdp): handle NaN values like numpy.
  _check_arraylike("partition", a)
  arr = asarray(a)
  if issubdtype(arr.dtype, np.complexfloating):
    raise NotImplementedError("jnp.partition for complex dtype is not implemented.")
  axis = _canonicalize_axis(axis, arr.ndim)
  kth = _canonicalize_axis(kth, arr.shape[axis])

  arr = swapaxes(arr, axis, -1)
  # The smallest values are found as the top_k of an order-reversed copy.
  # For integers, bitwise NOT reverses the order exactly and without overflow
  # (~x == -x - 1 for signed types, MAX - x for unsigned ones). Negation would
  # wrap around for unsigned values and for the minimum signed value.
  if issubdtype(arr.dtype, np.integer):
    bottom = lax.bitwise_not(lax.top_k(lax.bitwise_not(arr), kth + 1)[0])
  else:
    bottom = -lax.top_k(-arr, kth + 1)[0]
  top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0]
  out = lax.concatenate([bottom, top], dimension=arr.ndim - 1)
  return swapaxes(out, -1, axis)
@_wraps(np.argpartition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
  # Indices that partially sort ``a`` along ``axis``: the first ``kth + 1``
  # indices point at the smallest values in ascending order, followed by the
  # indices of the remaining values. Raises NotImplementedError for complex
  # inputs.
  # TODO(jakevdp): handle NaN values like numpy.
  _check_arraylike("argpartition", a)
  arr = asarray(a)
  if issubdtype(arr.dtype, np.complexfloating):
    raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
  axis = _canonicalize_axis(axis, arr.ndim)
  kth = _canonicalize_axis(kth, arr.shape[axis])
  arr = swapaxes(arr, axis, -1)
  # Select the smallest kth + 1 entries as the top_k of an order-reversed copy.
  # For integers use bitwise NOT rather than negation: negation wraps for
  # unsigned types (0 would be treated as the largest value) and overflows for
  # the most negative signed value, while ~x is strictly decreasing.
  if issubdtype(arr.dtype, np.integer):
    bottom_ind = lax.top_k(lax.not_(arr), kth + 1)[1]
  else:
    bottom_ind = lax.top_k(-arr, kth + 1)[1]
  # To avoid issues with duplicate values, we compute the top indices via a proxy:
  # positions already chosen for the bottom are zeroed in an array of ones, so
  # top_k of the proxy returns exactly the remaining positions.
  set_to_zero = lambda a, i: a.at[i].set(0)
  for _ in range(arr.ndim - 1):
    set_to_zero = jax.vmap(set_to_zero)
  proxy = set_to_zero(ones(arr.shape), bottom_ind)
  top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1]
  out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1)
  return swapaxes(out, -1, axis)
@partial(jit, static_argnums=(2,))
def _roll(a, shift, axis):
  """Rolls ``a`` by ``shift`` along ``axis``; flattens first when axis is None.

  ``shift`` and ``axis`` may each be a scalar or a 1D sequence. They are
  broadcast against each other, and each (shift, axis) pair is applied in turn.
  ``axis`` must be static (hashable) because of the jit static_argnums.
  """
  orig_shape = shape(a)
  if axis is None:
    flat_rolled = _roll(ravel(a), shift, axis=0)
    return lax.reshape(flat_rolled, orig_shape)
  shift = asarray(shift)
  rank = len(orig_shape)
  axis = np.asarray(axis)
  pair_shape = lax.broadcast_shapes(shift.shape, axis.shape, (1,))
  if len(pair_shape) != 1:
    raise ValueError("'shift' and 'axis' arguments to roll must be scalars or 1D arrays")
  shifts = broadcast_to(shift, pair_shape)
  axes = np.broadcast_to(axis, pair_shape)
  for amount, ax in zip(shifts, axes):
    ax = _canonicalize_axis(ax, rank)
    length = array(orig_shape[ax], dtype=np.int32)
    # Reduce the shift modulo the axis length; the max(..., 1) guards against
    # a zero divisor when the axis is empty.
    amount = remainder(lax.convert_element_type(amount, np.int32),
                       lax.max(length, np.int32(1)))
    # A roll is a window of length ``orig_shape[ax]`` into the axis
    # concatenated with itself.
    doubled = lax.concatenate((a, a), ax)
    a = lax.dynamic_slice_in_dim(doubled, length - amount, orig_shape[ax], axis=ax)
  return a
@_wraps(np.roll)
def roll(a, shift, axis: Optional[Union[int, Sequence[int]]] = None):
  # ``axis`` is a static argument of the jitted ``_roll`` and must be hashable,
  # so a list of axes is converted to a tuple before dispatching.
  _check_arraylike("roll", a)
  static_axis = tuple(axis) if isinstance(axis, list) else axis
  return _roll(a, shift, static_axis)
@_wraps(np.rollaxis, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis', 'start'))
def rollaxis(a, axis: int, start=0):
  # Moves ``axis`` so that it ends up just before position ``start``; this is
  # expressed as a moveaxis after translating ``start`` into a destination.
  # Raises ValueError when ``start`` is outside [-ndim, ndim].
  _check_arraylike("rollaxis", a)
  start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()")
  rank = ndim(a)
  axis = _canonicalize_axis(axis, rank)
  if start < -rank or start > rank:
    raise ValueError(f"{start=} must satisfy {-rank}<=start<={rank}")
  dest = start + rank if start < 0 else start
  # Removing ``axis`` shifts every later position down by one.
  if dest > axis:
    dest -= 1
  return moveaxis(a, axis, dest)
@_wraps(np.packbits)
@partial(jit, static_argnames=('axis', 'bitorder'))
def packbits(a, axis: Optional[int] = None, bitorder='big'):
  # Packs the elements of an integer or boolean array into uint8 bytes along
  # ``axis`` (the flattened array when ``axis`` is None), eight elements per
  # byte, zero-padding the packed axis up to a multiple of 8.
  # Raises TypeError for non-integer/non-boolean input and ValueError for an
  # unknown ``bitorder``.
  _check_arraylike("packbits", a)
  if not (issubdtype(_dtype(a), integer) or issubdtype(_dtype(a), bool_)):
    raise TypeError('Expected an input array of integer or boolean data type')
  if bitorder not in ['little', 'big']:
    raise ValueError("'order' must be either 'little' or 'big'")
  # As in NumPy, every nonzero element (including negative values) becomes a
  # set bit; a ``> 0`` test would silently drop negative inputs.
  a = lax.ne(a, _lax_const(a, 0)).astype('uint8')
  bits = arange(8, dtype='uint8')
  if bitorder == 'big':
    bits = bits[::-1]
  if axis is None:
    a = ravel(a)
    axis = 0
  a = swapaxes(a, axis, -1)
  leftover = a.shape[-1] % 8
  if leftover:
    a = lax.pad(a, np.uint8(0),
                (a.ndim - 1) * [(0, 0, 0)] + [(0, 8 - leftover, 0)])
  a = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8))
  bits = expand_dims(bits, tuple(range(a.ndim - 1)))
  packed = (a << bits).sum(-1).astype('uint8')
  return swapaxes(packed, axis, -1)
@_wraps(np.unpackbits)
@partial(jit, static_argnames=('axis', 'count', 'bitorder'))
def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
  # Expands each uint8 element of ``a`` into 8 uint8 bits along ``axis`` (the
  # flattened array when ``axis`` is None). ``count`` keeps only the first
  # ``count`` bits (negative values trim from the end); as in NumPy, a
  # ``count`` larger than the number of available bits zero-pads the output.
  # Raises TypeError for non-uint8 input and ValueError for an unknown
  # ``bitorder``.
  _check_arraylike("unpackbits", a)
  if _dtype(a) != uint8:
    raise TypeError("Expected an input array of unsigned byte data type")
  if bitorder not in ['little', 'big']:
    raise ValueError("'order' must be either 'little' or 'big'")
  bits = asarray(1) << arange(8, dtype='uint8')
  if bitorder == 'big':
    bits = bits[::-1]
  if axis is None:
    a = ravel(a)
    axis = 0
  a = swapaxes(a, axis, -1)
  unpacked = ((a[..., None] & expand_dims(bits, tuple(range(a.ndim)))) > 0).astype('uint8')
  unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))
  if count is not None:
    available = unpacked.shape[-1]
    if count > available:
      unpacked = lax.pad(unpacked, np.uint8(0),
                         (unpacked.ndim - 1) * [(0, 0, 0)] + [(0, count - available, 0)])
    else:
      unpacked = unpacked[..., :count]
  return swapaxes(unpacked, axis, -1)
@_wraps(np.take, skip_params=['out'],
        lax_description="""
Unlike NumPy, JAX does not raise an error for out-of-bound indices. By default
(``mode="fill"``), out-of-bound indices return a fill value; alternative
out-of-bound index semantics can be specified via the ``mode`` parameter (see
below).
""",
        extra_params="""
mode : string, default="fill"
    Out-of-bounds indexing mode: one of ``"fill"``, ``"wrap"`` or ``"clip"``
    (``"raise"`` is not supported). The default mode="fill" returns invalid
    values (e.g. NaN) for out-of-bounds indices (see also ``fill_value`` below).
    For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`.
fill_value : optional
    The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
    otherwise. Defaults to NaN for inexact types, the largest negative value for
    signed types, the largest positive value for unsigned types, and True for booleans.
unique_indices : bool, default=False
    If True, the implementation will assume that the indices are unique,
    which can result in more efficient execution on some backends.
indices_are_sorted : bool, default=False
    If True, the implementation will assume that the indices are sorted in
    ascending order, which can lead to more efficient execution on some backends.
""")
def take(a, indices, axis: Optional[int] = None, out=None, mode=None,
         unique_indices=False, indices_are_sorted=False, fill_value=None):
  # ``axis`` is converted to a Python int here because it is a static argument
  # of the jitted implementation ``_take``.
  return _take(a, indices, None if axis is None else operator.index(axis), out,
               mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
               fill_value=fill_value)
@partial(jit, static_argnames=('axis', 'mode', 'unique_indices', 'indices_are_sorted', 'fill_value'))
def _take(a, indices, axis: Optional[int] = None, out=None, mode=None,
          unique_indices=False, indices_are_sorted=False, fill_value=None):
  """Jitted implementation of :func:`take` as a single lax.gather.

  Args:
    a: array to take from; flattened when ``axis`` is None.
    indices: integer indices into ``a`` along ``axis``.
    axis: static axis to take along, or None.
    out: unsupported; must be None.
    mode: None/"fill" (negative indices wrap, out-of-bounds are filled),
      "wrap" (indices taken modulo the axis length) or "clip".
    unique_indices, indices_are_sorted, fill_value: forwarded to lax.gather.

  Returns:
    An array of shape ``a.shape[:axis] + indices.shape + a.shape[axis + 1:]``.

  Raises:
    NotImplementedError: if ``out`` is given or ``mode == "raise"``.
    ValueError: for an unknown ``mode``.
    IndexError: when taking a non-empty set of indices from an empty axis.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
  _check_arraylike("take", a, indices)
  a = asarray(a)
  indices = asarray(indices)
  if axis is None:
    a = ravel(a)
    axis_idx = 0
  else:
    axis_idx = _canonicalize_axis(axis, ndim(a))
  if mode is None or mode == "fill":
    gather_mode = lax.GatherScatterMode.FILL_OR_DROP
    # lax.gather() does not support negative indices, so we wrap them here
    indices = where(indices < 0, indices + a.shape[axis_idx], indices)
  elif mode == "raise":
    # TODO(phawkins): we have no way to report out of bounds errors yet.
    raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
  elif mode == "wrap":
    indices = mod(indices, _lax_const(indices, a.shape[axis_idx]))
    gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
  elif mode == "clip":
    gather_mode = lax.GatherScatterMode.CLIP
  else:
    raise ValueError(f"Invalid mode '{mode}' for np.take")
  index_dims = len(shape(indices))
  slice_sizes = list(shape(a))
  # Empty indices always produce an empty result whose shape embeds
  # indices.shape. Check this before the empty-axis case so that, e.g., indices
  # of shape (2, 0) into an empty axis yield shape (2, 0) rather than a.shape.
  if indices.size == 0:
    out_shape = (slice_sizes[:axis_idx] + list(indices.shape) +
                 slice_sizes[axis_idx + 1:])
    return full_like(a, 0, shape=out_shape)
  if slice_sizes[axis_idx] == 0:
    raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")
  slice_sizes[axis_idx] = 1
  dnums = lax.GatherDimensionNumbers(
    offset_dims=tuple(
      list(range(axis_idx)) +
      list(range(axis_idx + index_dims, len(a.shape) + index_dims - 1))),
    collapsed_slice_dims=(axis_idx,),
    start_index_map=(axis_idx,))
  return lax.gather(a, indices[..., None], dimension_numbers=dnums,
                    slice_sizes=tuple(slice_sizes),
                    mode=gather_mode, unique_indices=unique_indices,
                    indices_are_sorted=indices_are_sorted, fill_value=fill_value)
def _normalize_index(index, axis_size):
  """Maps an index in [-N, N) onto [0, N) for an axis of length N.

  Unsigned indices are returned unchanged. Python/NumPy integer scalars are
  adjusted eagerly with a Python comparison; anything else is adjusted
  elementwise with lax.select. ``axis_size`` may be a non-constant (symbolic)
  dimension, in which case it is converted to a value of the index's dtype.
  """
  if issubdtype(_dtype(index), np.unsignedinteger):
    return index
  if not core.is_constant_dim(axis_size):
    size = lax.convert_element_type(core.dimension_as_value(axis_size),
                                    _dtype(index))
  else:
    size = _lax_const(index, axis_size)
  if not isinstance(index, (int, np.integer)):
    return lax.select(index < 0, lax.add(index, size), index)
  if index < 0:
    return lax.add(index, size)
  return index
TAKE_ALONG_AXIS_DOC = """
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
an optional ``mode`` parameter controlling how out-of-bounds indices should be
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
indexing in JAX.
"""
@_wraps(np.take_along_axis, update_doc=False,
        lax_description=TAKE_ALONG_AXIS_DOC)
@partial(jit, static_argnames=('axis', 'mode'))
def take_along_axis(arr, indices, axis: Optional[int],
                    mode: Optional[Union[str, lax.GatherScatterMode]] = None):
  # Selects values of ``arr`` at ``indices`` along ``axis`` using a single
  # lax.gather. Outside ``axis``, size-1 dimensions of ``arr`` and ``indices``
  # broadcast against each other. With ``axis=None`` the array is flattened
  # and ``indices`` must be 1D. ``mode`` defaults to "fill".
  # Raises TypeError for non-integer indices and ValueError for mismatched ranks.
  _check_arraylike("take_along_axis", arr, indices)
  index_dtype = dtypes.dtype(indices)
  if not dtypes.issubdtype(index_dtype, integer):
    raise TypeError("take_along_axis indices must be of integer type, got "
                    f"{str(index_dtype)}")
  if axis is None:
    if ndim(indices) != 1:
      msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
      raise ValueError(msg.format(indices.shape))
    # Forward ``mode`` so out-of-bounds handling matches the axis != None path.
    return take_along_axis(arr.ravel(), indices, 0, mode=mode)
  rank = ndim(arr)
  if rank != ndim(indices):
    msg = "indices and arr must have the same number of dimensions; {} vs. {}"
    raise ValueError(msg.format(ndim(indices), ndim(arr)))
  axis = _canonicalize_axis(axis, rank)
  def replace(tup, val):
    lst = list(tup)
    lst[axis] = val
    return tuple(lst)
  use_64bit_index = _any([not core.is_constant_dim(d) or d >= (1 << 31) for d in arr.shape])
  index_dtype = dtype(int64 if use_64bit_index else int32)
  indices = lax.convert_element_type(indices, index_dtype)
  axis_size = arr.shape[axis]
  arr_shape = replace(arr.shape, 1)
  idx_shape = indices.shape
  out_shape = lax.broadcast_shapes(idx_shape, arr_shape)
  if axis_size == 0:
    return zeros(out_shape, arr.dtype)
  # Dimensions that need an explicit gather index: ``axis`` itself plus every
  # dimension where ``indices`` is not of size 1.
  index_dims = [i for i, idx in enumerate(idx_shape) if i == axis or not core.symbolic_equal_dim(idx, 1)]
  gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,)
  gather_indices = []
  slice_sizes = []
  offset_dims = []
  start_index_map = []
  collapsed_slice_dims = []
  j = 0
  for i in range(rank):
    if i == axis:
      indices = _normalize_index(indices, axis_size)
      gather_indices.append(lax.reshape(indices, gather_index_shape))
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1
    elif core.symbolic_equal_dim(idx_shape[i], 1):
      # If idx_shape[i] == 1, we can just take the entirety of the arr's axis
      # and avoid forming an iota index.
      offset_dims.append(i)
      slice_sizes.append(arr_shape[i])
    elif core.symbolic_equal_dim(arr_shape[i], 1):
      # If the array dimension is 1 but the index dimension is not, we
      # broadcast the array dimension to the index dimension by repeatedly
      # gathering the first element.
      gather_indices.append(zeros(gather_index_shape, dtype=index_dtype))
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1
    else:
      # Otherwise, idx_shape[i] == arr_shape[i]. Use an iota index so
      # corresponding elements of array and index are gathered.
      # TODO(mattjj): next line needs updating for dynamic shapes
      iota = lax.broadcasted_iota(index_dtype, gather_index_shape, j)
      gather_indices.append(iota)
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1
  gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
  dnums = lax.GatherDimensionNumbers(
    offset_dims=tuple(offset_dims),
    collapsed_slice_dims=tuple(collapsed_slice_dims),
    start_index_map=tuple(start_index_map))
  return lax.gather(arr, gather_indices_arr, dnums, tuple(slice_sizes),
                    mode="fill" if mode is None else mode)
### Indexing
def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                    mode=None, fill_value=None):
  # Computes arr[idx].
  # All supported cases of indexing can be implemented as an XLA gather,
  # followed by an optional reverse and broadcast_in_dim.
  # Handle some special cases, falling back if error messages might differ.
  if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
      not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
    if 0 <= idx < arr.shape[0]:
      # Use dynamic rather than static index here to avoid slow repeated execution:
      # See https://github.com/google/jax/issues/12198
      return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
  if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
      isinstance(idx, slice) and
      (type(idx.start) is int or idx.start is None) and
      (type(idx.stop) is int or idx.stop is None) and
      (type(idx.step) is int or idx.step is None)):
    n = arr.shape[0]
    start = idx.start if idx.start is not None else 0
    stop = idx.stop if idx.stop is not None else n
    step = idx.step if idx.step is not None else 1
    if (0 <= start < n and 0 <= stop <= n and 0 < step and
        (start, stop, step) != (0, n, 1)):
      if step == 1:
        # Use dynamic rather than static slice here to avoid slow repeated execution:
        # See https://github.com/google/jax/issues/12198
        return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
      if not _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
        return lax.slice_in_dim(arr, start, stop, step)
      # TODO(mattjj, sharadmv): handle step != 1 when trailing dims are
      # dynamic; for now fall through to the general gather path below.
  # TODO(mattjj,dougalm): expand dynamic shape indexing support
  if jax.config.jax_dynamic_shapes and arr.ndim > 0:
    try:
      aval = core.get_aval(idx)
    except Exception:
      # ``idx`` has no abstract value (e.g. a slice or tuple); use the
      # general path below.
      pass
    else:
      if (isinstance(aval, core.DShapedArray) and aval.shape == () and
          dtypes.issubdtype(aval.dtype, np.integer) and
          not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
          isinstance(arr.shape[0], int)):
        return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
  treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
  return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
                 unique_indices, mode, fill_value)
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(1, 2))
def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
            unique_indices, mode, fill_value):
  """Computes ``arr[idx]`` for an index split by _split_index_for_jit.

  The index is recombined, translated into an _Indexer, and executed as a
  lax.gather followed by an optional reversal and insertion of newaxis dims.
  ``fill_value`` must be a concrete scalar when given.
  """
  idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  if fill_value is not None:
    core.concrete_or_error(None, fill_value,
                           "fill_value argument to indexed get()")
    if np.ndim(fill_value) != 0:
      raise ValueError("fill_value argument to indexed get() must be a scalar")
    if isinstance(fill_value, np.ndarray):
      fill_value = fill_value.item()
  # An empty result needs no gather at all; this is both a fast path and the
  # way cases like zeros(0)[array([], int32)] are handled.
  if core.is_empty_shape(indexer.slice_shape):
    return zeros_like(arr, shape=indexer.slice_shape)
  result = arr
  # Skip the gather entirely when there are no gather indices.
  if not core.is_empty_shape(indexer.gather_indices.shape):
    result = lax.gather(
        result, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape,
        unique_indices=unique_indices or indexer.unique_indices,
        indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
        mode=mode, fill_value=fill_value)
  # Axes indexed with negative strides were gathered forwards; flip them.
  if indexer.reversed_y_dims:
    result = lax.rev(result, indexer.reversed_y_dims)
  # Insert the np.newaxis/None dimensions.
  return expand_dims(result, indexer.newaxis_dims)
# Description of an indexing expression as a lax.gather, as produced by
# _index_to_gather and consumed by _gather.
_Indexer = collections.namedtuple("_Indexer", [
  # The expected shape of the slice output.
  "slice_shape",
  # The slice shape to pass to lax.gather().
  "gather_slice_shape",
  # The gather indices to use.
  "gather_indices",
  # A GatherDimensionNumbers object describing the gather to perform.
  "dnums",
  # Are the gather_indices known to be non-overlapping and/or sorted?
  # (In practice, these translate to "there are no advanced indices", because
  # only advanced indices could lead to index repetition.)
  "unique_indices",
  "indices_are_sorted",
  # Slice dimensions that have negative strides, and so must be reversed after
  # the gather.
  "reversed_y_dims",
  # Keep track of any axes created by `newaxis`. These must be inserted for
  # gathers and eliminated for scatters.
  "newaxis_dims",
])
def _split_index_for_jit(idx, shape):
  """Splits an index into a static part and a dynamic part.

  Used to pass indices into `jit`-ted function. Ellipsis and slices must be
  static, so they go into the static tuple (slices as ``(start, stop, step)``
  triples); every other pytree leaf goes into the dynamic list. Each part holds
  ``None`` at positions owned by the other. Recombine the two parts with
  _merge_static_and_dynamic_indices.
  """
  # Convert list indices to tuples in cases (deprecated by NumPy.)
  idx = _eliminate_deprecated_list_indexing(idx)
  # Expand any (concrete) boolean indices into integer indices so the advanced
  # integer indexing logic can handle them.
  idx = _expand_bool_indices(idx, shape)
  leaves, treedef = tree_flatten(idx)
  def static_part(leaf):
    if leaf is Ellipsis:
      return leaf
    if isinstance(leaf, slice):
      # slice objects aren't hashable.
      return (leaf.start, leaf.stop, leaf.step)
    return None
  def is_static(leaf):
    return leaf is Ellipsis or isinstance(leaf, slice)
  static = tuple(static_part(leaf) for leaf in leaves)
  dynamic = [None if is_static(leaf) else leaf for leaf in leaves]
  return treedef, static, dynamic
def _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
"""Recombines indices that were split by _split_index_for_jit."""
idx = []
for s, d in zip(static_idx, dynamic_idx):
if d is not None:
idx.append(d)
elif isinstance(s, tuple):
idx.append(slice(s[0], s[1], s[2]))
else:
idx.append(s)
return treedef.unflatten(idx)
def _int(aval):
  """Returns True if ``aval`` describes a rank-0 integer value."""
  if aval.shape:
    return False
  return issubdtype(aval.dtype, integer)
def _index_to_gather(x_shape, idx, normalize_indices=True):
  """Translates a NumPy-style index into the parameters of a lax.gather.

  Args:
    x_shape: shape of the array being indexed.
    idx: the index (tuple form; Ellipsis and missing trailing dims are filled
      in by _canonicalize_tuple_index).
    normalize_indices: if True, negative integer and advanced indices are
      mapped into range with _normalize_index.

  Returns:
    An _Indexer describing the gather, the output (slice) shape, the dims to
    reverse for negative-stride slices, and the newaxis dims to insert.

  Raises:
    IndexError: for indexing into a size-0 axis with an int, non-static slice
      bounds, slicing a non-constant dimension, or unsupported index types.
    TypeError: for indexers whose dtype is neither integer nor boolean.
  """
  # Remove ellipses and add trailing slice(None)s.
  idx = _canonicalize_tuple_index(len(x_shape), idx)
  # Check for advanced indexing:
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
  # Do the advanced indexing axes appear contiguously? If not, NumPy semantics
  # move the advanced axes to the front.
  advanced_axes_are_contiguous = False
  advanced_indexes = None
  # The positions of the advanced indexing axes in `idx`.
  idx_advanced_axes = []
  # The positions of the advanced indexes in x's shape, counted after None
  # axes have been removed. See below.
  x_advanced_axes = None
  if _is_advanced_int_indexer(idx):
    idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
    # Triples of (index array, position in idx, position in x).
    advanced_pairs = (
      (asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
      if isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray)))
    if normalize_indices:
      advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
                        for e, i, j in advanced_pairs)
    advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
    advanced_axes_are_contiguous = np.all(np.diff(idx_advanced_axes) == 1)
  x_axis = 0  # Current axis in x.
  y_axis = 0  # Current axis in y, before collapsing. See below.
  collapsed_y_axis = 0  # Current axis in y, after collapsing.
  # Gather dimension numbers (the same layout is reused for scatters).
  offset_dims = []
  collapsed_slice_dims = []
  start_index_map = []
  # Use 64-bit indices if any dimension is symbolic or too large for int32.
  use_64bit_index = _any([not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape])
  index_dtype = int64 if use_64bit_index else int32
  # Gather indices.
  # Pairs of (array, start_dim) values. These will be broadcast into
  # gather_indices_shape, with the array dimensions aligned to start_dim, and
  # then concatenated.
  gather_indices = []
  gather_indices_shape = []
  # We perform three transformations to y before the scatter op, in order:
  # First, y is broadcast to slice_shape. In general `y` only needs to be
  # broadcastable to the right shape.
  slice_shape = []
  # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
  # indices, which the scatter cannot remove itself.
  newaxis_dims = []
  # Finally, we reverse reversed_y_dims to handle slices with negative strides.
  reversed_y_dims = []
  gather_slice_shape = []
  for idx_pos, i in enumerate(idx):
    # Handle the advanced indices here if:
    # * the advanced indices were not contiguous and we are the start.
    # * we are at the position of the first advanced index.
    if (advanced_indexes is not None and
        (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
         not advanced_axes_are_contiguous and idx_pos == 0)):
      # All advanced indexes are broadcast together and contribute a single
      # block of output dimensions.
      advanced_indexes = broadcast_arrays(*advanced_indexes)
      shape = advanced_indexes[0].shape
      ndim = len(shape)
      start_dim = len(gather_indices_shape)
      gather_indices += ((lax.convert_element_type(a, index_dtype), start_dim)
                         for a in advanced_indexes)
      gather_indices_shape += shape
      start_index_map.extend(x_advanced_axes)
      collapsed_slice_dims.extend(x_advanced_axes)
      slice_shape.extend(shape)
      y_axis += ndim
      collapsed_y_axis += ndim
    # Per-index bookkeeping for advanced indexes.
    if idx_pos in idx_advanced_axes:
      x_axis += 1
      gather_slice_shape.append(1)
      continue
    try:
      abstract_i = core.get_aval(i)
    except TypeError:
      abstract_i = None
    # Handle basic int indexes.
    if isinstance(abstract_i, (ConcreteArray, ShapedArray)) and _int(abstract_i):
      if core.symbolic_equal_dim(x_shape[x_axis], 0):
        # XLA gives error when indexing into an axis of size 0
        raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
      i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
      i = lax.convert_element_type(i, index_dtype)
      gather_indices.append((i, len(gather_indices_shape)))
      collapsed_slice_dims.append(x_axis)
      gather_slice_shape.append(1)
      start_index_map.append(x_axis)
      x_axis += 1
    # Handle np.newaxis (None)
    elif i is None:
      slice_shape.append(1)
      newaxis_dims.append(y_axis)
      y_axis += 1
    elif isinstance(i, slice):
      # Normalize the slice to use None when possible
      start, stop, step = i.start, i.stop, i.step
      try:
        if step is None or core.symbolic_equal_dim(step, 1):
          step = None
        if step is None:
          if start is None or core.symbolic_equal_dim(start, 0):
            start = None
          if stop is None or (not isinstance(stop, core.Tracer) and
              core.greater_equal_dim(stop, x_shape[x_axis])):
            stop = None
        elif core.symbolic_equal_dim(step, -1):
          step = -1
      except (TypeError, core.InconclusiveDimensionOperation):
        pass
      # Handle slice(None) and slice(None, None, -1)
      if start is None and stop is None and (
          step is None or isinstance(step, int) and step == -1):
        if step == -1:
          reversed_y_dims.append(collapsed_y_axis)
        slice_shape.append(x_shape[x_axis])
        gather_slice_shape.append(x_shape[x_axis])
        offset_dims.append(collapsed_y_axis)
        collapsed_y_axis += 1
        y_axis += 1
        x_axis += 1
      # Handle slice index (only static, otherwise an error is raised)
      else:
        if not _all(_is_slice_element_none_or_constant(elt)
                    for elt in (start, stop, step)):
          msg = ("Array slice indices must have static start/stop/step to be used "
                 "with NumPy indexing syntax. "
                 f"Found slice({start}, {stop}, {step}). "
                 "To index a statically sized "
                 "array at a dynamic position, try lax.dynamic_slice/"
                 "dynamic_update_slice (JAX does not support dynamically sized "
                 "arrays within JIT compiled functions).")
          raise IndexError(msg)
        if not core.is_constant_dim(x_shape[x_axis]):
          msg = ("Cannot use NumPy slice indexing on an array dimension whose "
                 f"size is not statically known ({x_shape[x_axis]}). "
                 "Try using lax.dynamic_slice/dynamic_update_slice")
          raise IndexError(msg)
        # Convert to a forward slice; negative steps are undone by reversing
        # this output dim after the gather.
        start, limit, stride, needs_rev = _static_idx(slice(start, stop, step),
                                                      x_shape[x_axis])
        if needs_rev:
          reversed_y_dims.append(collapsed_y_axis)
        if stride == 1:
          # A unit-stride slice is a single start index plus a full-size offset
          # (slice) dimension.
          i = lax.convert_element_type(start, index_dtype)
          gather_indices.append((i, len(gather_indices_shape)))
          slice_shape.append(limit - start)
          gather_slice_shape.append(limit - start)
          offset_dims.append(collapsed_y_axis)
          start_index_map.append(x_axis)
        else:
          # A strided slice becomes an explicit arange of gather indices.
          i = arange(start, limit, stride, dtype=index_dtype)
          size = i.shape[0]
          slice_shape.append(size)
          gather_slice_shape.append(1)
          gather_indices.append((i, len(gather_indices_shape)))
          gather_indices_shape.append(size)
          start_index_map.append(x_axis)
          collapsed_slice_dims.append(x_axis)
        collapsed_y_axis += 1
        y_axis += 1
        x_axis += 1
    else:
      if (abstract_i is not None and
          not (issubdtype(abstract_i.dtype, integer) or issubdtype(abstract_i.dtype, bool_))):
        msg = ("Indexer must have integer or boolean type, got indexer "
               "with type {} at position {}, indexer value {}")
        raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
      msg = "Indexing mode not yet supported. Open a feature request!\n{}"
      raise IndexError(msg.format(idx))
  # Assemble the index array with a trailing index-vector dimension.
  if len(gather_indices) == 0:
    gather_indices_array = np.zeros((0,), dtype=index_dtype)
  elif len(gather_indices) == 1:
    g, _ = gather_indices[0]
    gather_indices_array = lax.expand_dims(g, (g.ndim,))
  else:
    last_dim = len(gather_indices_shape)
    gather_indices_shape.append(1)
    gather_indices_array = lax.concatenate([
      lax.broadcast_in_dim(g, gather_indices_shape, tuple(range(i, i + g.ndim)))
      for g, i in gather_indices],
      last_dim)
  dnums = lax.GatherDimensionNumbers(
    offset_dims = tuple(offset_dims),
    collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)),
    start_index_map = tuple(start_index_map)
  )
  return _Indexer(
    slice_shape=slice_shape,
    newaxis_dims=tuple(newaxis_dims),
    gather_slice_shape=gather_slice_shape,
    reversed_y_dims=reversed_y_dims,
    dnums=dnums,
    gather_indices=gather_indices_array,
    unique_indices=advanced_indexes is None,
    indices_are_sorted=advanced_indexes is None)
def _should_unpack_list_index(x):
"""Helper for _eliminate_deprecated_list_indexing."""
return (isinstance(x, (np.ndarray, Array)) and np.ndim(x) != 0
or isinstance(x, (Sequence, slice))
or x is Ellipsis or x is None)
def _eliminate_deprecated_list_indexing(idx):
  """Normalizes ``idx`` to a tuple, rejecting non-tuple sequence indices.

  Tuples are returned unchanged and non-sequence indices (including arrays)
  are wrapped in a 1-tuple. Any other sequence raises TypeError.
  """
  # "Basic slicing is initiated if the selection object is a non-array,
  # non-tuple sequence containing slice objects, [Ellipses, or newaxis
  # objects]". Detects this and raises a TypeError.
  if isinstance(idx, tuple):
    return idx
  if not isinstance(idx, Sequence) or isinstance(idx, (Array, np.ndarray)):
    return (idx,)
  # As of numpy 1.16, some non-tuple sequences of indices result in a warning, while
  # others are converted to arrays, based on a set of somewhat convoluted heuristics
  # (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343)
  # In JAX, we raise an informative TypeError for *all* non-tuple sequences;
  # the message suggests the fix matching what the sequence looks like.
  looks_multidimensional = _any(_should_unpack_list_index(i) for i in idx)
  if looks_multidimensional:
    msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
           "use `arr[tuple(seq)]` instead of `arr[seq]`. "
           "See https://github.com/google/jax/issues/4564 for more information.")
  else:
    msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
           "use `arr[array(seq)]` instead of `arr[seq]`. "
           "See https://github.com/google/jax/issues/4564 for more information.")
  raise TypeError(msg)
def _is_boolean_index(i):
  """Returns a truthy value if ``i`` is a boolean-dtype index.

  Covers arrays whose abstract value is a boolean ShapedArray and non-empty
  Python lists made only of boolean scalars.
  """
  try:
    aval = core.get_aval(i)
  except TypeError:
    aval = None
  if isinstance(aval, ShapedArray) and issubdtype(aval.dtype, bool_):
    return True
  return (isinstance(i, list) and i and
          _all(_is_scalar(e) and issubdtype(_dtype(e), np.bool_) for e in i))
def _expand_bool_indices(idx, shape):
  """Converts concrete bool indexes into advanced integer indexes.

  Each boolean mask is replaced by the integer index arrays of its True
  positions (one per mask dimension, via np.where). Non-boolean entries are
  passed through unchanged.

  Args:
    idx: tuple of index entries.
    shape: shape of the array being indexed.

  Returns:
    A tuple of index entries with all boolean masks expanded.

  Raises:
    IndexError: for more than one Ellipsis or a mask whose shape does not match
      the dimensions it indexes.
    NonConcreteBooleanIndexError: for a traced (non-concrete) boolean mask.
    TypeError: for a boolean scalar index.
  """
  out = []
  total_dims = len(shape)
  num_ellipsis = _sum(e is Ellipsis for e in idx)
  if num_ellipsis > 1:
    raise IndexError("an index can only have a single ellipsis ('...')")
  elif num_ellipsis == 1:
    # Number of array dimensions consumed by explicit (non-None) entries.
    total_dims = _sum(_ndim(e) if _is_boolean_index(e) else 1 for e in idx
                      if e is not None and e is not Ellipsis)
  # ``out`` entries do not map one-to-one onto array dimensions: the Ellipsis
  # entry stands for several dimensions and None entries consume none. These
  # offsets translate len(out) into the first array dimension a mask indexes.
  ellipsis_offset = 0
  newaxis_offset = 0
  for dim_number, i in enumerate(idx):
    try:
      abstract_i = core.get_aval(i)
    except TypeError:
      abstract_i = None
    if _is_boolean_index(i):
      if isinstance(i, list):
        i = array(i)
        abstract_i = core.get_aval(i)
      if not type(abstract_i) is ConcreteArray:
        # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
        raise errors.NonConcreteBooleanIndexError(abstract_i)
      elif _ndim(i) == 0:
        raise TypeError("JAX arrays do not support boolean scalar indices")
      else:
        i_shape = _shape(i)
        start = len(out) + ellipsis_offset - newaxis_offset
        expected_shape = shape[start: start + _ndim(i)]
        if i_shape != expected_shape:
          raise IndexError("boolean index did not match shape of indexed array in index "
                           f"{dim_number}: got {i_shape}, expected {expected_shape}")
        out.extend(np.where(i))
    else:
      out.append(i)
    if i is Ellipsis:
      ellipsis_offset = len(shape) - total_dims - 1
    if i is None:
      newaxis_offset += 1
  return tuple(out)
def _is_slice_element_none_or_constant(elt):
  """Returns True if ``elt`` is None or its abstract value is a ConcreteArray.

  Values with no abstract value (get_aval raises TypeError) count as False.
  """
  if elt is not None:
    try:
      aval = core.get_aval(elt)
    except TypeError:
      return False
    return type(aval) is ConcreteArray
  return True
# TODO(mattjj): clean up this logic
def _is_advanced_int_indexer(idx):
  """Returns True if idx should trigger int array indexing, False otherwise.

  An index made only of None/Ellipsis/slices and integer scalars is basic
  indexing. Otherwise it is advanced integer indexing when every remaining
  entry is integer array-like.
  """
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
  assert isinstance(idx, tuple)
  def is_basic_entry(e):
    return e is None or e is Ellipsis or isinstance(e, slice)
  def is_int_scalar(e):
    return _is_scalar(e) and issubdtype(_dtype(e), np.integer)
  if _all(is_basic_entry(e) or is_int_scalar(e) for e in idx):
    return False
  return _all(is_basic_entry(e) or _is_int_arraylike(e) for e in idx)
def _is_int_arraylike(x):
"""Returns True if x is array-like with integer dtype, False otherwise."""
return (isinstance(x, int) and not isinstance(x, bool)
or issubdtype(getattr(x, "dtype", None), np.integer)
or isinstance(x, (list, tuple)) and _all(_is_int_arraylike(e) for e in x))
def _is_scalar(x):
"""Checks if a Python or NumPy scalar."""
return np.isscalar(x) or (isinstance(x, (np.ndarray, Array))
and np.ndim(x) == 0)
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
len_without_none = _sum(1 for e in idx if e is not None and e is not Ellipsis)
if len_without_none > arr_ndim:
raise IndexError(
f"Too many indices for {array_name}: {len_without_none} "
f"non-None/Ellipsis indices for dim {arr_ndim}.")
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
if ellipsis_index is not None:
if next(ellipses, None) is not None:
raise IndexError(
f"Multiple ellipses (...) not supported: {list(map(type, idx))}.")
colons = (slice(None),) * (arr_ndim - len_without_none)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif len_without_none < arr_ndim:
colons = (slice(None),) * (arr_ndim - len_without_none)
idx = tuple(idx) + colons
return idx
def _static_idx(idx: slice, size: DimSize):
"""Helper function to compute the static slice start/limit/stride values."""
if isinstance(size, int):
start, stop, step = idx.indices(size)
else:
raise TypeError(size)
if (step < 0 and stop >= start) or (step > 0 and start >= stop):
return 0, 0, 1, False # sliced to size zero
if step > 0:
return start, stop, step, False
else:
k = (start - stop - 1) % (-step)
return stop + k + 1, start + 1, -step, True
@_wraps(np.blackman)
def blackman(M: int) -> Array:
  # Blackman window of ``M`` points in the canonical default float dtype.
  M = core.concrete_or_error(int, M, "M argument of jnp.blackman")
  dtype = dtypes.canonicalize_dtype(float_)
  if M <= 1:
    # Like NumPy, return an empty window for M <= 0 (``ones`` rejects negative
    # sizes) and a single 1.0 for M == 1.
    return ones(_max(M, 0), dtype)
  n = lax.iota(dtype, M)
  return 0.42 - 0.5 * cos(2 * pi * n / (M - 1)) + 0.08 * cos(4 * pi * n / (M - 1))
@_wraps(np.bartlett)
def bartlett(M: int) -> Array:
  # Bartlett (triangular) window of ``M`` points in the default float dtype.
  M = core.concrete_or_error(int, M, "M argument of jnp.bartlett")
  dtype = dtypes.canonicalize_dtype(float_)
  if M <= 1:
    # Like NumPy, return an empty window for M <= 0 (``ones`` rejects negative
    # sizes) and a single 1.0 for M == 1.
    return ones(_max(M, 0), dtype)
  n = lax.iota(dtype, M)
  return 1 - abs(2 * n + 1 - M) / (M - 1)
@_wraps(np.hamming)
def hamming(M: int) -> Array:
  # Hamming window of ``M`` points in the canonical default float dtype.
  M = core.concrete_or_error(int, M, "M argument of jnp.hamming")
  dtype = dtypes.canonicalize_dtype(float_)
  if M <= 1:
    # Like NumPy, return an empty window for M <= 0 (``ones`` rejects negative
    # sizes) and a single 1.0 for M == 1.
    return ones(_max(M, 0), dtype)
  n = lax.iota(dtype, M)
  return 0.54 - 0.46 * cos(2 * pi * n / (M - 1))
@_wraps(np.hanning)
def hanning(M: int) -> Array:
  # Hann window of ``M`` points in the canonical default float dtype.
  M = core.concrete_or_error(int, M, "M argument of jnp.hanning")
  dtype = dtypes.canonicalize_dtype(float_)
  if M <= 1:
    # Like NumPy, return an empty window for M <= 0 (``ones`` rejects negative
    # sizes) and a single 1.0 for M == 1.
    return ones(_max(M, 0), dtype)
  n = lax.iota(dtype, M)
  return 0.5 * (1 - cos(2 * pi * n / (M - 1)))
@_wraps(np.kaiser)
def kaiser(M: int, beta: ArrayLike) -> Array:
  # M determines the output shape, so it must be a concrete Python int.
  M = core.concrete_or_error(int, M, "M argument of jnp.kaiser")
  dtype = dtypes.canonicalize_dtype(float_)
  if M <= 1:
    # Degenerate window: NumPy returns an empty array for M < 1 and [1.] for
    # M == 1. Clamp to 0 so a negative M gives an empty array rather than an
    # invalid (negative) shape.
    return ones(M if M > 0 else 0, dtype)
  n = lax.iota(dtype, M)
  alpha = 0.5 * (M - 1)
  # I0(beta * sqrt(1 - ((n - alpha) / alpha)^2)) / I0(beta)
  return i0(beta * sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta)
def _gcd_cond_fn(xs: Tuple[Array, Array]) -> Array:
  """while_loop predicate for ``gcd``: continue while any second operand is nonzero."""
  _, remainder = xs
  return any(remainder != 0)
def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]:
  """One elementwise Euclid step for ``gcd``.

  Where the second operand is nonzero, ``(a, b)`` becomes ``(b, a rem b)``;
  elsewhere the pair is left as ``(a, 0)``. The pair is then reordered so the
  first element is the larger one.
  """
  a, b = xs
  active = b != 0
  a, b = where(active, b, a), where(active, lax.rem(a, b), _lax_const(b, 0))
  swap = a < b
  return where(swap, b, a), where(swap, a, b)
@_wraps(np.gcd, module='numpy')
@jit
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
  _check_arraylike("gcd", x1, x2)
  x1, x2 = _promote_dtypes(x1, x2)
  if not issubdtype(_dtype(x1), integer):
    raise ValueError("Arguments to jax.numpy.gcd must be integers.")
  x1, x2 = broadcast_arrays(x1, x2)
  # Run Euclid's algorithm on the magnitudes until every remainder is zero;
  # the first element of the final state holds the result.
  state = (abs(x1), abs(x2))
  result, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, state)
  return result
@_wraps(np.lcm, module='numpy')
@jit
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
  _check_arraylike("lcm", x1, x2)
  x1, x2 = _promote_dtypes(x1, x2)
  x1, x2 = abs(x1), abs(x2)
  if not issubdtype(_dtype(x1), integer):
    raise ValueError("Arguments to jax.numpy.lcm must be integers.")
  divisor = gcd(x1, x2)
  # lcm(a, b) = a * (b // gcd(a, b)); a zero gcd (both inputs zero) gives 0.
  product = multiply(x1, floor_divide(x2, divisor))
  return where(divisor == 0, _lax_const(divisor, 0), product)
@_wraps(np.extract)
def extract(condition: ArrayLike, arr: ArrayLike) -> Array:
  # Both inputs are flattened, then the selection is delegated to compress.
  flat_condition = ravel(condition)
  flat_arr = ravel(arr)
  return compress(flat_condition, flat_arr)
@_wraps(np.compress, skip_params=['out'])
def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = None,
             out: None = None) -> Array:
  _check_arraylike("compress", condition, a)
  condition_arr = asarray(condition).astype(bool)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
  if condition_arr.ndim != 1:
    raise ValueError("condition must be a 1D array")
  # Bring the selection axis to the front; axis=None selects from the flattened input.
  if axis is None:
    arr, axis = ravel(a), 0
  else:
    arr = moveaxis(a, axis, 0)
  num_rows = arr.shape[0]
  # Condition entries beyond the end of the axis are only allowed if they are False.
  condition_arr, extra = condition_arr[:num_rows], condition_arr[num_rows:]
  if any(extra):
    raise ValueError("condition contains entries that are out of bounds")
  # A condition shorter than the axis drops the rows it does not cover.
  selected = arr[:condition_arr.shape[0]][condition_arr]
  return moveaxis(selected, 0, axis)
@_wraps(np.cov)
@partial(jit, static_argnames=('rowvar', 'bias', 'ddof'))
def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True,
        bias: bool = False, ddof: Optional[int] = None,
        fweights: Optional[ArrayLike] = None,
        aweights: Optional[ArrayLike] = None) -> Array:
  # Covariance estimate following np.cov: center each variable (row of X) on
  # its (optionally weighted) mean, form X @ X^H and divide by a normalization
  # factor `f` that accounts for ddof and the weights.
  if y is not None:
    m, y = _promote_args_inexact("cov", m, y)
    if y.ndim > 2:
      raise ValueError("y has more than 2 dimensions")
  else:
    m, = _promote_args_inexact("cov", m)

  if m.ndim > 2:
    raise ValueError("m has more than 2 dimensions")  # same as numpy error

  # Arrange X as (variables, observations); rowvar=False means columns are variables.
  X = atleast_2d(m)
  if not rowvar and X.shape[0] != 1:
    X = X.T
  if X.shape[0] == 0:
    return array([]).reshape(0, 0)

  # Observations from `y` are appended as additional variables (rows).
  if y is not None:
    y_arr = atleast_2d(y)
    if not rowvar and y_arr.shape[0] != 1:
      y_arr = y_arr.T
    X = concatenate((X, y_arr), axis=0)
  # Default ddof: 1 (unbiased) unless bias is requested.
  if ddof is None:
    ddof = 1 if bias == 0 else 0

  # Combined per-observation weight vector (product of fweights and aweights).
  w: Optional[Array] = None
  if fweights is not None:
    _check_arraylike("cov", fweights)
    if ndim(fweights) > 1:
      raise RuntimeError("cannot handle multidimensional fweights")
    if shape(fweights)[0] != X.shape[1]:
      raise RuntimeError("incompatible numbers of samples and fweights")
    if not issubdtype(_dtype(fweights), integer):
      raise TypeError("fweights must be integer.")
    # Ensure positive fweights; note that numpy raises an error on negative fweights.
    w = asarray(abs(fweights))
  if aweights is not None:
    _check_arraylike("cov", aweights)
    if ndim(aweights) > 1:
      raise RuntimeError("cannot handle multidimensional aweights")
    if shape(aweights)[0] != X.shape[1]:
      raise RuntimeError("incompatible numbers of samples and aweights")
    # Ensure positive aweights: note that numpy raises an error for negative aweights.
    aweights = abs(aweights)
    w = asarray(aweights) if w is None else w * asarray(aweights)

  # Weighted mean of every variable; w_sum has one (identical) entry per row.
  avg, w_sum = average(X, axis=1, weights=w, returned=True)
  w_sum = w_sum[0]

  # Normalization factor, mirroring np.cov's handling of ddof with weights.
  if w is None:
    f = X.shape[1] - ddof
  elif ddof == 0:
    f = w_sum
  elif aweights is None:
    f = w_sum - ddof
  else:
    f = w_sum - ddof * sum(w * aweights) / w_sum

  X = X - avg[:, None]
  # Apply the weights to one factor only so each product term is weighted once.
  X_T = X.T if w is None else (X * lax.broadcast_to_rank(w, X.ndim)).T
  # squeeze() turns a single-variable result into a 0-d array, as numpy does.
  return true_divide(dot(X, X_T.conj()), f).squeeze()
@_wraps(np.corrcoef)
@partial(jit, static_argnames=('rowvar',))
def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -> Array:
  _check_arraylike("corrcoef", x)
  c = cov(x, y, rowvar)
  if ndim(c) == 0:
    # scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
    return divide(c, c)
  # Normalize the covariance by the outer product of standard deviations.
  stddev = sqrt(real(diag(c))).astype(c.dtype)
  c = c / stddev[:, None] / stddev[None, :]
  # Clip rounding error so the coefficients stay within [-1, 1]; complex inputs
  # have their real and imaginary parts clipped separately.
  real_part = clip(real(c), -1, 1)
  if not iscomplexobj(c):
    return real_part
  imag_part = clip(imag(c), -1, 1)
  return lax.complex(real_part, imag_part)
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             out: None = None, overwrite_input: bool = False, method: str = "linear",
             keepdims: bool = False, interpolation: None = None) -> Array:
  _check_arraylike("quantile", a, q)
  unsupported = overwrite_input or out is not None
  if unsupported:
    raise ValueError("jax.numpy.quantile does not support overwrite_input=True or "
                     "out != None")
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  # The deprecated `interpolation` argument takes precedence when given.
  chosen_method = interpolation or method
  return _quantile(asarray(a), asarray(q), axis, chosen_method, keepdims, False)
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                out: None = None, overwrite_input: bool = False, method: str = "linear",
                keepdims: bool = False, interpolation: None = None) -> Array:
  _check_arraylike("nanquantile", a, q)
  unsupported = overwrite_input or out is not None
  if unsupported:
    raise ValueError("jax.numpy.nanquantile does not support overwrite_input=True or "
                     "out != None")
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  # The deprecated `interpolation` argument takes precedence when given;
  # the final True asks _quantile to ignore NaNs.
  chosen_method = interpolation or method
  return _quantile(asarray(a), asarray(q), axis, chosen_method, keepdims, True)
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
              interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
  """Shared implementation of ``quantile`` and ``nanquantile``.

  Args:
    a: input array; promoted to an inexact dtype. Complex input is rejected.
    q: quantile(s), rank 0 or 1. Promoted to an inexact dtype so that integer
      quantiles such as ``0`` or ``1`` work (the index math below uses
      ``lax.floor``/``lax.ceil``, which require floating-point operands).
      Values are expected in [0, 1]; they are not validated here.
    axis: None (operate on the flattened array), an int, or a tuple of ints.
      A tuple is handled by moving the reduced axes to the end and merging
      them into a single axis.
    interpolation: one of 'linear', 'lower', 'higher', 'midpoint', 'nearest'.
    keepdims: keep the reduced dimension(s) with size 1.
    squash_nans: if True, NaNs are ignored (``nanquantile``); otherwise any
      NaN along ``axis`` makes the corresponding result NaN.

  Returns:
    The quantiles cast to the (promoted) dtype of ``a``. When ``q`` is 1-D,
    the result has a leading dimension of length ``len(q)``.

  Raises:
    ValueError: for an unknown interpolation, complex input, a repeated axis,
      or ``q`` of rank > 1.
  """
  if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
    raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
                     "'midpoint', or 'nearest'")
  a, = _promote_dtypes_inexact(a)
  # Integer q would otherwise reach lax.floor/lax.ceil with an integer dtype.
  q, = _promote_dtypes_inexact(q)
  keepdim = []
  if issubdtype(a.dtype, np.complexfloating):
    raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
  if axis is None:
    a = ravel(a)
    axis = 0
  elif isinstance(axis, tuple):
    # Shape to restore when keepdims=True: reduced axes become size 1.
    keepdim = list(shape(a))
    nd = ndim(a)
    axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
    if len(set(axis)) != len(axis):
      raise ValueError('repeated axis')
    for ax in axis:
      keepdim[ax] = 1

    keep = set(range(nd)) - set(axis)
    # prepare permutation: move the kept axes to the front (the order of the
    # reduced axes does not matter since they are merged into one below).
    dimensions = list(range(nd))
    for i, s in enumerate(sorted(keep)):
      dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
    do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis)
    touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis)
    a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
    axis = _canonicalize_axis(-1, ndim(a))
  else:
    axis = _canonicalize_axis(axis, ndim(a))

  q_shape = shape(q)
  q_ndim = ndim(q)
  if q_ndim > 1:
    raise ValueError(f"q must have rank <= 1, got shape {shape(q)}")

  a_shape = shape(a)

  if squash_nans:
    a = where(isnan(a), nan, a) # Ensure nans are positive so they sort to the end.
    a = lax.sort(a, dimension=axis)
    # Number of non-NaN entries along the axis; the valid data occupies the
    # first `counts` sorted positions.
    counts = sum(logical_not(isnan(a)), axis=axis, dtype=q.dtype,
                 keepdims=keepdims)
    shape_after_reduction = counts.shape
    q = lax.expand_dims(
      q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
    counts = lax.expand_dims(counts, tuple(range(q_ndim)))
    # Fractional position of each quantile within the valid entries.
    q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
    high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
    low = lax.convert_element_type(low, int64)
    high = lax.convert_element_type(high, int64)
    out_shape = q_shape + shape_after_reduction
    # Build a full advanced index: iotas for every non-reduced dimension and the
    # computed low/high positions for the reduced axis.
    index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim)
             for dim in range(len(shape_after_reduction))]
    if keepdims:
      index[axis] = low
    else:
      index.insert(axis, low)
    low_value = a[tuple(index)]
    index[axis] = high
    high_value = a[tuple(index)]
  else:
    # Any NaN along the axis makes every quantile of that slice NaN.
    a = where(any(isnan(a), axis=axis, keepdims=True), nan, a)
    a = lax.sort(a, dimension=axis)
    n = lax.convert_element_type(array(a_shape[axis]), lax_internal._dtype(q))
    # Fractional position of each quantile within the sorted axis.
    q = lax.mul(q, n - 1)
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.clamp(_lax_const(low, 0), low, n - 1)
    high = lax.clamp(_lax_const(high, 0), high, n - 1)
    low = lax.convert_element_type(low, int64)
    high = lax.convert_element_type(high, int64)

    # Gather one element along `axis` per quantile; the q dimension (if any)
    # becomes the leading dimension of the result.
    slice_sizes = list(a_shape)
    slice_sizes[axis] = 1
    dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(range(
        q_ndim,
        len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
      collapsed_slice_dims=() if keepdims else (axis,),
      start_index_map=(axis,))
    low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
                           slice_sizes=slice_sizes)
    high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
                            slice_sizes=slice_sizes)
    if q_ndim == 1:
      low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
                                        broadcast_dimensions=(0,))
      high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
                                         broadcast_dimensions=(0,))

  if interpolation == "linear":
    result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
                     lax.mul(high_value.astype(q.dtype), high_weight))
  elif interpolation == "lower":
    result = low_value
  elif interpolation == "higher":
    result = high_value
  elif interpolation == "nearest":
    # Ties (weight exactly 0.5) resolve to the lower value.
    pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
    result = lax.select(pred, low_value, high_value)
  elif interpolation == "midpoint":
    result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
  else:
    raise ValueError(f"interpolation={interpolation!r} not recognized")
  if keepdims and keepdim:
    if q_ndim > 0:
      keepdim = [shape(q)[0], *keepdim]
    result = reshape(result, keepdim)
  return lax.convert_element_type(result, a.dtype)
@partial(vectorize, excluded={0, 2, 3})
def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
  """Binary-search insertion index of a single ``query`` into ``sorted_arr``.

  Runs a fixed number of bisection steps (ceil(log2(len + 1))) with
  ``lax.fori_loop``; ``vectorize`` maps it over the elements of ``query``.
  """
  compare = _sort_le_comparator if side == 'left' else _sort_lt_comparator

  def bisect_step(_, bounds):
    lo, hi = bounds
    mid = (lo + hi) // 2
    go_left = compare(query, sorted_arr[mid])
    return where(go_left, lo, mid), where(go_left, mid, hi)

  num_steps = int(np.ceil(np.log2(len(sorted_arr) + 1)))
  initial_bounds = (dtype(0), dtype(len(sorted_arr)))
  _, upper = lax.fori_loop(0, num_steps, bisect_step, initial_bounds)
  return upper
def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
  """Insertion indices computed from ranks in a combined sort.

  The rank of each query within the concatenation of queries and
  ``sorted_arr``, minus its rank among the queries alone, is the number of
  ``sorted_arr`` entries that precede it. Putting the queries first (side='left')
  or last (side='right') decides how ties are broken.
  """
  fits_int32 = sorted_arr.size + query.size < np.iinfo(np.int32).max
  rank_dtype = int32 if fits_int32 else int64

  def rank_of(values):
    positions = lax.iota(rank_dtype, len(values))
    return zeros_like(positions).at[argsort(values)].set(positions)

  flat_query = query.ravel()
  if side == 'left':
    combined = rank_of(lax.concatenate([flat_query, sorted_arr], 0))[:query.size]
  else:
    combined = rank_of(lax.concatenate([sorted_arr, flat_query], 0))[sorted_arr.size:]
  preceding = lax.sub(combined, rank_of(flat_query))
  return lax.reshape(preceding, np.shape(query)).astype(dtype)
@_wraps(np.searchsorted, skip_params=['sorter'],
  extra_params=_dedent("""
    method : str
        One of 'scan' (default) or 'sort'. Controls the method used by the implementation; 'scan'
        tends to be more performant on CPU (particularly when ``a`` is very large), while
        'sort' is often more performant on accelerator backends like GPU and TPU (particularly
        when ``v`` is very large)."""))
@partial(jit, static_argnames=('side', 'sorter', 'method'))
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
                 sorter: None = None, *, method: str = 'scan') -> Array:
  _check_arraylike("searchsorted", a, v)
  if side not in ('left', 'right'):
    raise ValueError(f"{side!r} is an invalid value for keyword 'side'. "
                     "Expected one of ['left', 'right'].")
  if method not in ('scan', 'sort'):
    raise ValueError(f"{method!r} is an invalid value for keyword 'method'. "
                     "Expected one of ['sort', 'scan'].")
  if sorter is not None:
    raise NotImplementedError("sorter is not implemented")
  if ndim(a) != 1:
    raise ValueError("a should be 1-dimensional")
  a, v = _promote_dtypes(a, v)
  num_sorted = len(a)
  # Use the narrowest index dtype that can represent every insertion point.
  dtype = int32 if num_sorted <= np.iinfo(np.int32).max else int64
  if num_sorted == 0:
    return zeros_like(v, dtype=dtype)
  implementations = {'scan': _searchsorted_via_scan, 'sort': _searchsorted_via_sort}
  return implementations[method](asarray(a), asarray(v), side, dtype)
@_wraps(np.digitize)
@partial(jit, static_argnames=('right',))
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
  _check_arraylike("digitize", x, bins)
  right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()")
  bins_arr = asarray(bins)
  if bins_arr.ndim != 1:
    raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}")
  if bins_arr.shape[0] == 0:
    # With no bins every value falls into bin 0, so the result is all zeros with
    # the *shape* of x (as in NumPy). zeros_like is required here: passing x to
    # zeros() would interpret its values as a shape, which is wrong and fails
    # for traced inputs (this function is jitted).
    return zeros_like(x, dtype=dtypes.canonicalize_dtype(int_))
  # np.digitize's `right` flag corresponds to the opposite searchsorted side.
  side = 'right' if not right else 'left'
  # Bins may be increasing or decreasing; for decreasing bins, search the
  # reversed array and mirror the resulting index.
  return where(
    bins_arr[-1] >= bins_arr[0],
    searchsorted(bins_arr, x, side=side),
    len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side)
  )
# Extra documentation passed as `lax_description` to the `_wraps` decorator of
# `piecewise` below, explaining the traceability requirement on `funclist`.
_PIECEWISE_DOC = """\
Unlike `np.piecewise`, :py:func:`jax.numpy.piecewise` requires functions in
`funclist` to be traceable by JAX, as it is implemented via :func:`jax.lax.switch`.
See the :func:`jax.lax.switch` documentation for more information.
"""
@_wraps(np.piecewise, lax_description=_PIECEWISE_DOC)
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
              funclist: List[Union[ArrayLike, Callable[..., Array]]],
              *args, **kw) -> Array:
  _check_arraylike("piecewise", x)
  nc, nf = len(condlist), len(funclist)
  # Normalize so that funclist[0] is the default ("otherwise") entry and
  # funclist[i + 1] pairs with condlist[i].
  if nf == nc + 1:
    funclist = funclist[-1:] + funclist[:-1]
  elif nf == nc:
    funclist = [0] + list(funclist)
  else:
    raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}")
  # Split by position into constant values and callables.
  consts = {}
  funcs = {}
  for i, f in enumerate(funclist):
    if callable(f):
      funcs[i] = f
    else:
      consts[i] = f
  return _piecewise(asarray(x), asarray(condlist, dtype=bool_), consts,
                    frozenset(funcs.items()),  # dict is not hashable.
                    *args, **kw)
# `funcs` is a frozenset of (index, callable) pairs so it can be a static
# (hashable) jit argument; `consts` maps indices to constant values and is traced.
@partial(jit, static_argnames=['funcs'])
def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike],
               funcs: FrozenSet[Tuple[int, Callable[..., Array]]],
               *args, **kw) -> Array:
  """Jitted core of ``piecewise``.

  Entry 0 of the rebuilt function list is the default case and entry ``i + 1``
  corresponds to ``condlist[i]``. For every element of ``x`` a branch index is
  selected and applied via ``lax.switch`` (vectorized over elements).
  """
  funcdict = dict(funcs)
  funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)]
  # Prepend a zero row and take the running count of true conditions; argmax
  # then finds the first row where the count reaches its maximum, i.e.
  # 1 + the index of the *last* true condition, or 0 (the default) if none is true.
  indices = argmax(cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0)
  dtype = _dtype(x)
  # Wrap callables and constants so every branch returns x's dtype.
  def _call(f):
    return lambda x: f(x, *args, **kw).astype(dtype)
  def _const(v):
    return lambda x: array(v, dtype=dtype)
  funclist = [_call(f) if callable(f) else _const(f) for f in funclist]
  # excluded=(1,): the branch list is passed through unvectorized.
  return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
               axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out: None = None, overwrite_input: bool = False, method: str = "linear",
               keepdims: bool = False, interpolation: None = None) -> Array:
  _check_arraylike("percentile", a, q)
  # Convert percentages to fractions in an inexact dtype and defer to quantile.
  q, = _promote_dtypes_inexact(q)
  fraction = q / 100
  return quantile(a, fraction, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, method=method, keepdims=keepdims)
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
                  axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  out: None = None, overwrite_input: bool = False, method: str = "linear",
                  keepdims: bool = False, interpolation: None = None) -> Array:
  _check_arraylike("nanpercentile", a, q)
  # Convert percentages to fractions and defer to the NaN-ignoring quantile.
  hundred = float32(100.0)
  fraction = true_divide(q, hundred)
  return nanquantile(a, fraction, axis=axis, out=out, overwrite_input=overwrite_input,
                     interpolation=interpolation, method=method,
                     keepdims=keepdims)
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
           out: None = None, overwrite_input: bool = False,
           keepdims: bool = False) -> Array:
  _check_arraylike("median", a)
  # The median is the 0.5 quantile; 'midpoint' averages the two middle values
  # when the count is even.
  options = dict(axis=axis, out=out, overwrite_input=overwrite_input,
                 keepdims=keepdims)
  return quantile(a, 0.5, method='midpoint', **options)
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              out: None = None, overwrite_input: bool = False,
              keepdims: bool = False) -> Array:
  _check_arraylike("nanmedian", a)
  # NaN-ignoring 0.5 quantile with midpoint interpolation.
  options = dict(axis=axis, out=out, overwrite_input=overwrite_input,
                 keepdims=keepdims)
  return nanquantile(a, 0.5, method='midpoint', **options)
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
  """Copy the array and cast to a specified dtype.

  This is implemented via :func:`jax.lax.convert_element_type`, which may
  behave slightly differently from :meth:`numpy.ndarray.astype` in some
  cases. In particular, the details of float-to-int and int-to-float casts
  are implementation dependent. A ``dtype`` of ``None`` means the default
  floating-point dtype.
  """
  target = dtypes.canonicalize_dtype(float_) if dtype is None else dtype
  lax_internal._check_user_dtype_supported(target, "astype")
  return lax.convert_element_type(arr, target)
def _nbytes(arr: ArrayLike) -> int:
  """Total number of bytes occupied by the array's elements (size * itemsize)."""
  bytes_per_item = _dtype(arr).itemsize
  return size(arr) * bytes_per_item
def _itemsize(arr: ArrayLike) -> int:
  """Number of bytes used by a single element of the array."""
  element_dtype = _dtype(arr)
  return element_dtype.itemsize
def _clip(number: ArrayLike,
          min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None,  # noqa: F811
          out: None = None) -> Array:
  """Return an array whose values are limited to the range given by ``min`` and ``max``.

  Refer to :func:`jax.numpy.clip` for full documentation."""
  return clip(number, out=out, a_min=min, a_max=max)
def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
  """Return a bitwise copy of the array, viewed as a new dtype.

  This is a fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`.

  If the source and target dtype have the same bitwidth, the result has the same
  shape as the input array. If the bitwidth of the target dtype is different
  from the source, the size of the last axis of the result is adjusted
  accordingly.

  >>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape
  (1, 2, 6)
  >>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape
  (1, 2, 2)

  Conversions involving booleans are not well-defined in all situations. With
  regards to the shape of result as explained above, booleans are treated as
  having a bitwidth of 8. However, when converting to a boolean array, the input
  should only contain 0 or 1 bytes. Otherwise, results may be unpredictable or
  may change depending on how the result is used.

  This conversion is guaranteed and safe:

  >>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)
  Array([ True, False,  True], dtype=bool)

  However, there are no guarantees about the results of any expression involving
  a view such as this: `jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)`.
  In particular, the results may change between JAX releases and depending on
  the platform. To safely convert such an array to a boolean array, compare it
  with `0`:

  >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0
  Array([ True,  True, False], dtype=bool)
  """
  if type is not None:
    raise NotImplementedError("`type` argument of array.view() is not supported.")

  _check_arraylike("view", arr)
  arr = asarray(arr)

  lax_internal._check_user_dtype_supported(dtype, "view")
  dtype = dtypes.canonicalize_dtype(dtype)

  # 0-d arrays: only same-itemsize views are allowed; lift to 1-d, view, squeeze.
  if arr.ndim == 0:
    if arr.dtype.itemsize != dtype.itemsize:
      raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.")
    return _view(lax.expand_dims(arr, (0,)), dtype).squeeze()

  # The byte length of the last axis must split evenly into target items.
  if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0:
    raise ValueError("When changing to a larger dtype, its size must be a divisor "
                     "of the total size in bytes of the last axis of the array.")

  if arr.dtype == dtype:
    return arr

  # lax.bitcast_convert_type does not support bool or complex; in these cases we
  # cast to a compatible type and recursively call _view for simplicity.
  if arr.dtype == bool:
    return _view(arr.astype('uint8'), dtype)

  # Complex source: interleave real and imaginary parts along the last axis as
  # the matching real dtype, then view that.
  if issubdtype(arr.dtype, complexfloating):
    new_shape = (*arr.shape[:-1], arr.shape[-1] * 2)
    new_dtype = finfo(arr.dtype).dtype
    arr = (zeros(new_shape, new_dtype)
             .at[..., 0::2].set(arr.real)
             .at[..., 1::2].set(arr.imag))
    return _view(arr, dtype)

  if dtype == bool:
    return _view(arr, uint8).astype(bool)

  # Complex target: view as the matching real dtype, then pair up consecutive
  # entries as (real, imag).
  if issubdtype(dtype, complexfloating):
    out = _view(arr, finfo(dtype).dtype).astype(dtype)
    return out[..., 0::2] + 1j * out[..., 1::2]

  # lax.bitcast_convert_type adds or subtracts dimensions depending on the
  # relative bitwidths of the dtypes; we account for that with reshapes.
  if arr.dtype.itemsize < dtype.itemsize:
    # Group `factor` source items into a trailing axis that the bitcast consumes.
    factor = dtype.itemsize // arr.dtype.itemsize
    arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor)
    return lax.bitcast_convert_type(arr, dtype)

  if arr.dtype.itemsize > dtype.itemsize:
    # The bitcast appends a trailing axis; merge it into the last axis.
    out = lax.bitcast_convert_type(arr, dtype)
    return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])

  return lax.bitcast_convert_type(arr, dtype)
def _notimplemented_flat(self):
  """Stand-in for ``arr.flat``: always raises NotImplementedError."""
  del self  # unused; the property is unsupported for every array
  message = ("JAX DeviceArrays do not implement the arr.flat property: "
             "consider arr.flatten() instead.")
  raise NotImplementedError(message)
@_wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
:class:`NotImplementedError`, because ``np.place`` modifies its arguments in-place,
and in JAX arrays are immutable. A JAX-compatible approach to array updates
can be found in :attr:`jax.numpy.ndarray.at`.
""")
def place(*args, **kwargs):
  # Always rejected: in-place modification is impossible for immutable JAX arrays.
  del args, kwargs
  msg = ("jax.numpy.place is not implemented because JAX arrays cannot be modified in-place. "
         "For functional approaches to updating array values, see jax.numpy.ndarray.at: "
         "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")
  raise NotImplementedError(msg)
@_wraps(np.put, lax_description="""
Numpy function :func:`numpy.put` is not available in JAX and will raise a
:class:`NotImplementedError`, because ``np.put`` modifies its arguments in-place,
and in JAX arrays are immutable. A JAX-compatible approach to array updates
can be found in :attr:`jax.numpy.ndarray.at`.
""")
def put(*args, **kwargs):
  # Always rejected: in-place modification is impossible for immutable JAX arrays.
  del args, kwargs
  msg = ("jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
         "For functional approaches to updating array values, see jax.numpy.ndarray.at: "
         "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")
  raise NotImplementedError(msg)
### add method and operator overloads to arraylike classes

# We add operator overloads to DeviceArray and ShapedArray. These method and
# operator overloads mainly just forward calls to the corresponding lax_numpy
# functions, which can themselves handle instances from any of these classes.
# NOTE(review): the registration of these overloads happens outside this
# section — confirm the list of target classes above is still current.

# Python and NumPy scalar types. NOTE(review): not referenced in this section;
# presumably used where the overloads are registered.
_scalar_types = (int, float, complex, np.generic)
# Right-hand operand types that _defer_to_unrecognized_arg forwards to the
# wrapped binary operation.
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array)
# Operand types for which _defer_to_unrecognized_arg raises TypeError instead
# of returning NotImplemented.
_rejected_binop_types = (list, tuple, set, dict)
def _defer_to_unrecognized_arg(opchar, binary_op, swap=False):
  """Build a binary dunder method that defers to unknown operand types.

  The returned method converts ``__jax_array__`` objects first, calls
  ``binary_op`` (with operands swapped for reflected operators when ``swap``)
  for accepted types, raises TypeError for list/tuple/set/dict operands, and
  otherwise returns NotImplemented so the other type can handle the operation.
  """
  # Ensure that other array types have the chance to override arithmetic.
  def deferring_binary_op(self, other):
    if hasattr(other, '__jax_array__'):
      other = other.__jax_array__()
    operands = (other, self) if swap else (self, other)
    if isinstance(other, _accepted_binop_types):
      return binary_op(*operands)
    if isinstance(other, _rejected_binop_types):
      lhs_name, rhs_name = (type(operand).__name__ for operand in operands)
      raise TypeError(f"unsupported operand type(s) for {opchar}: "
                      f"{lhs_name!r} and {rhs_name!r}")
    return NotImplemented
  return deferring_binary_op
def _unimplemented_setitem(self, i, x):
  """``__setitem__`` replacement: item assignment is always rejected with a TypeError."""
  del i, x  # the index and value are irrelevant; the assignment never happens
  raise TypeError(
      f"'{type(self)}' object does not support item assignment. JAX arrays are "
      "immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` "
      "or another .at[] method: "
      "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html")
def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array:
  """Implementation of the builtin ``round()`` protocol for arrays."""
  rounded = round(number, decimals=ndigits or 0)
  if ndigits is None:
    # Match the builtin: round(7.5) on a float returns an integer.
    return rounded.astype(int)
  return rounded
def _copy(self: Array) -> Array:
  """Return ``self.copy()``."""
  duplicate = self.copy()
  return duplicate
def _deepcopy(self: Array, memo: Any) -> Array:
  """Return ``self.copy()``; the ``memo`` dictionary is ignored."""
  _ = memo  # unused
  duplicate = self.copy()
  return duplicate
# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
# JAX's own array and tracer classes.
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl)
# Types for which __array_module__ reports that jax.numpy can handle the call.
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)
def __array_module__(self, types):
  """NEP-37 hook: return ``jax.numpy`` if every type in ``types`` is handled, else NotImplemented."""
  all_handled = builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types)
  return jax.numpy if all_handled else NotImplemented
def _compress_method(a: ArrayLike, condition: ArrayLike,
                     axis: Optional[int] = None, out: None = None) -> Array:
  """Return the slices of this array selected by ``condition`` along ``axis``.

  Method form of compress with the array first; refer to
  :func:`jax.numpy.compress` for full documentation."""
  return compress(condition, a, axis=axis, out=out)
@_wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast` instead.
""")
def _deprecated_broadcast(*args, **kwargs):
  # Method form kept for backward compatibility; warns, then forwards unchanged.
  msg = "The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead."
  warnings.warn(msg, category=FutureWarning)
  return lax.broadcast(*args, **kwargs)
# Documented from lax.broadcast_in_dim, the function this method forwards to
# (wrapping lax.broadcast would attach the wrong docstring and signature docs).
@_wraps(lax.broadcast_in_dim, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
""")
def _deprecated_broadcast_in_dim(*args, **kwargs):
  # Method form kept for backward compatibility; warns, then forwards unchanged.
  warnings.warn(
    "The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
    category=FutureWarning)
  return lax.broadcast_in_dim(*args, **kwargs)
# Documented from np.split, the NumPy counterpart of jnp.split that this method
# forwards to (wrapping lax.broadcast would attach an unrelated docstring).
@_wraps(np.split, lax_description="""
Deprecated. Use :func:`jax.numpy.split` instead.
""")
def _deprecated_split(*args, **kwargs):
  # Method form kept for backward compatibility; warns, then forwards unchanged.
  warnings.warn(
    "The arr.split() method is deprecated. Use jax.numpy.split instead.",
    category=FutureWarning)
  return split(*args, **kwargs)
@core.stash_axis_env()
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr: ArrayLike,
                 start_indices: Tuple[Tuple[int, ...]],
                 limit_indices: Tuple[Tuple[int, ...]],
                 removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]:
  """Extract several static slices of ``arr`` in one jitted call.

  Slice ``k`` spans ``start_indices[k]`` to ``limit_indices[k]`` and then has
  the dimensions in ``removed_dims[k]`` squeezed away (if any). The original
  note says this is used to shard DeviceArray arguments to pmap and lives here
  to avoid circular imports.
  """
  def _one_slice(starts, limits, removed):
    piece = lax.slice(arr, starts, limits)
    return lax.squeeze(piece, removed) if removed else piece

  slice_specs = safe_zip(start_indices, limit_indices, removed_dims)
  return [_one_slice(*spec) for spec in slice_specs]
# The next two functions back iteration over arrays (iter(device_array)); they
# are implemented here to avoid circular imports.
@jit
def _unstack(x: Array) -> List[Array]:
  """Split ``x`` along axis 0 into a list of ``x.shape[0]`` arrays with that axis removed."""
  leading = x.shape[0]
  return [lax.index_in_dim(x, pos, keepdims=False) for pos in range(leading)]
device_array.DeviceArray._unstack = _unstack
ArrayImpl._unstack = _unstack
def _chunk_iter(x, size):
  """Yield consecutive slices of ``x`` along axis 0 with ``size`` rows each.

  If ``size`` exceeds ``x.shape[0]`` the whole array is yielded once; otherwise
  any leftover rows form a final, shorter slice.
  """
  total = x.shape[0]
  if size > total:
    yield x
    return
  full_chunks, leftover = divmod(total, size)
  for chunk_idx in range(full_chunks):
    yield lax.dynamic_slice_in_dim(x, chunk_idx * size, size)
  if leftover:
    yield lax.dynamic_slice_in_dim(x, full_chunks * size, leftover)
device_array.DeviceArray._chunk_iter = _chunk_iter
ArrayImpl._chunk_iter = _chunk_iter
# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
# Note: this docstring will appear as the docstring for the `at` property.
"""Helper property for index update functionality.
The ``at`` property provides a functionally pure equivalent of in-place
array modificatons.
In particular:
============================== ================================
Alternate syntax Equivalent In-place expression
============================== ================================
``x = x.at[idx].set(y)`` ``x[idx] = y``
``x = x.at[idx].add(y)`` ``x[idx] += y``
``x = x.at[idx].multiply(y)`` ``x[idx] *= y``
``x = x.at[idx].divide(y)`` ``x[idx] /= y``
``x = x.at[idx].power(y)`` ``x[idx] **= y``
``x = x.at[idx].min(y)`` ``x[idx] = minimum(x[idx], y)``
``x = x.at[idx].max(y)`` ``x[idx] = maximum(x[idx], y)``
``x = x.at[idx].apply(ufunc)`` ``ufunc.at(x, idx)``
``x = x.at[idx].get()`` ``x = x[idx]``
============================== ================================
None of the ``x.at`` expressions modify the original ``x``; instead they return
a modified copy of ``x``. However, inside a :py:func:`~jax.jit` compiled function,
expressions like :code:`x = x.at[idx].set(y)` are guaranteed to be applied in-place.
Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple
indices refer to the same location, all updates will be applied (NumPy would
only apply the last update, rather than applying all updates.) The order
in which conflicting updates are applied is implementation-defined and may be
nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
index semantics can be specified via the ``mode`` parameter (see below).
Arguments
---------
mode : str
Specify out-of-bound indexing mode. Options are:
- ``"promise_in_bounds"``: (default) The user promises that indices are in bounds.
No additional checking will be performed. In practice, this means that
out-of-bounds indices in ``get()`` will be clipped, and out-of-bounds indices
in ``set()``, ``add()``, etc. will be dropped.
- ``"clip"``: clamp out of bounds indices into valid range.
- ``"drop"``: ignore out-of-bound indices.
- ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value``
argument specifies the value that will be returned.
See :class:`jax.lax.GatherScatterMode` for more details.
indices_are_sorted : bool
If True, the implementation will assume that the indices passed to ``at[]``
are sorted in ascending order, which can lead to more efficient execution
on some backends.
unique_indices : bool
If True, the implementation will assume that the indices passed to ``at[]``
are unique, which can result in more efficient execution on some backends.
fill_value : Any
Only applies to the ``get()`` method: the fill value to return for out-of-bounds
slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for
inexact types, the largest negative value for signed types, the largest positive
value for unsigned types, and ``True`` for booleans.
Examples
--------
>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].add(10)
Array([ 0., 1., 12., 3., 4.], dtype=float32)
>>> x.at[10].add(10) # out-of-bounds indices are ignored
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')
Array([ 0., 1., 2., 3., 14.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[20].get() # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value
Array(-1., dtype=float32)
"""
__slots__ = ("array",)
def __init__(self, array):
self.array = array
def __getitem__(self, index):
  """Bind ``index`` to the wrapped array, returning an ``_IndexUpdateRef``."""
  ref = _IndexUpdateRef(self.array, index)
  return ref
def __repr__(self):
return f"_IndexUpdateHelper({repr(self.array)})"
# Reuse the ``_IndexUpdateHelper`` docstring as the documentation of ``Array.at``.
Array.at.__doc__ = _IndexUpdateHelper.__doc__

# Module-level aliases of ``divide`` and ``power``. ``_IndexUpdateRef`` defines
# methods with these same names and calls the aliases from their bodies
# (presumably so the module functions are not confused with the methods).
_divide_fn = divide
_power_fn = power
class _IndexUpdateRef:
"""Helper object to call indexed update functions for an (advanced) index.
This object references a source array and a specific indexer into that array.
Methods on this object return copies of the source array that have been
modified at the positions specified by the indexer.
"""
__slots__ = ("array", "index")
def __init__(self, array, index):
self.array = array
self.index = index
def __repr__(self):
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
def get(self, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
"""Equivalent to ``x[idx]``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexing <numpy.doc.indexing>` ``x[idx]``. This function differs from
the usual array indexing syntax in that it allows additional keyword
arguments ``indices_are_sorted`` and ``unique_indices`` to be passed.
See :mod:`jax.ops` for details.
"""
return _rewriting_take(self.array, self.index,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
def set(self, values, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] = y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
See :mod:`jax.ops` for details.
"""
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def apply(self, func, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
Returns the value of ``x`` that would result from applying the unary
function ``func`` to ``x`` at the given indices. This is similar to
``x.at[idx].set(func(x[idx]))``, but differs in the case of repeated indices:
in ``x.at[idx].apply(func)``, repeated indices result in the function being
applied multiple times.
Note that in the current implementation, ``scatter_apply`` is not compatible
with automatic differentiation.
See :mod:`jax.ops` for details.
"""
def _scatter_apply(x, indices, _, dims, **kwargs):
return lax.scatter_apply(x, indices, func, dims, **kwargs)
return scatter._scatter_update(self.array, self.index,
lax_internal._zero(self.array.dtype),
_scatter_apply,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def add(self, values, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] += y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] += y``.
See :mod:`jax.ops` for details.
"""
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_add,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def multiply(self, values, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] *= y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] *= y``.
See :mod:`jax.ops` for details.
"""
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_mul,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode)
mul = multiply
def divide(self, values, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] /= y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] /= y``.
See :mod:`jax.ops` for details.
"""
return _divide_fn(
self.array,
scatter._scatter_update(ones_like(self.array), self.index, values,
lax.scatter_mul,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode))
def power(self, values, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] **= y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] **= y``.
See :mod:`jax.ops` for details.
"""
return _power_fn(
self.array,
scatter._scatter_update(ones_like(self.array), self.index, values,
lax.scatter_mul,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode))
def min(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
mode=None):
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>`
``x[idx] = minimum(x[idx], y)``.
See :mod:`jax.ops` for details.
"""
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_min,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def max(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
mode=None):
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>`
``x[idx] = maximum(x[idx], y)``.
See :mod:`jax.ops` for details.
"""
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_max,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def _build_array_operators():
  """Return the operator table used to populate JAX array dunder methods.

  Keys are operator names without underscores (``"add"``, ``"radd"``, ...).
  Reflected entries (``r<name>``) wrap the same function with ``swap=True``.
  Insertion order is significant only for readability; it matches the
  original hand-written table exactly.
  """
  defer = _defer_to_unrecognized_arg

  def forward_and_reflected(name, symbol, fn):
    # e.g. ("add", "+", add) -> "add" and its reflected twin "radd".
    return [(name, defer(symbol, fn)),
            (f"r{name}", defer(symbol, fn, swap=True))]

  entries = [
      ("getitem", _rewriting_take),
      ("setitem", _unimplemented_setitem),
      ("copy", _copy),
      ("deepcopy", _deepcopy),
      ("neg", negative),
      ("pos", positive),
      # Rich comparisons have no reflected variants in this table.
      ("eq", defer("==", equal)),
      ("ne", defer("!=", not_equal)),
      ("lt", defer("<", less)),
      ("le", defer("<=", less_equal)),
      ("gt", defer(">", greater)),
      ("ge", defer(">=", greater_equal)),
      ("abs", abs),
  ]
  reflectable = [
      ("add", "+", add),
      ("sub", "-", subtract),
      ("mul", "*", multiply),
      ("div", "/", divide),
      ("truediv", "/", true_divide),
      ("floordiv", "//", floor_divide),
      ("divmod", "divmod", divmod),
      ("mod", "%", mod),
      ("pow", "**", power),
      ("matmul", "@", matmul),
      ("and", "&", bitwise_and),
      ("or", "|", bitwise_or),
      ("xor", "^", bitwise_xor),
  ]
  for name, symbol, fn in reflectable:
    entries += forward_and_reflected(name, symbol, fn)
  entries += [
      ("invert", bitwise_not),
      # Shifts list both forward entries before their reflected twins.
      ("lshift", defer("<<", left_shift)),
      ("rshift", defer(">>", right_shift)),
      ("rlshift", defer("<<", left_shift, swap=True)),
      ("rrshift", defer(">>", right_shift, swap=True)),
      ("round", _operator_round),
  ]
  return dict(entries)


_array_operators = _build_array_operators()
# Methods installed on JAX arrays, mapped to their jax.numpy implementations.
_array_methods = dict(
    all=all,
    any=any,
    argmax=argmax,
    argmin=argmin,
    argpartition=argpartition,
    argsort=argsort,
    astype=_astype,
    choose=choose,
    clip=_clip,
    conj=conj,
    conjugate=conjugate,
    compress=_compress_method,
    copy=copy,
    cumprod=cumprod,
    cumsum=cumsum,
    diagonal=diagonal,
    dot=dot,
    flatten=ravel,
    item=_item,
    max=max,
    mean=mean,
    min=min,
    nonzero=nonzero,
    prod=prod,
    ptp=ptp,
    ravel=ravel,
    repeat=repeat,
    reshape=_reshape,
    round=round,
    searchsorted=searchsorted,
    sort=sort,
    squeeze=squeeze,
    std=std,
    sum=sum,
    swapaxes=swapaxes,
    take=take,
    trace=trace,
    transpose=_transpose,
    var=var,
    view=_view,
    # Private hooks exposed as methods to avoid circular imports elsewhere.
    _split=split,  # used in jacfwd/jacrev
    _multi_slice=_multi_slice,  # used in pxla for sharding
    # Deprecated methods.
    # TODO(jakevdp): remove these after June 2023
    broadcast=_deprecated_broadcast,
    broadcast_in_dim=_deprecated_broadcast_in_dim,
    split=_deprecated_split,
)
# Read-only attributes installed on JAX arrays; each value is the getter.
_array_properties = dict(
    flat=_notimplemented_flat,
    T=transpose,
    real=real,
    imag=imag,
    nbytes=_nbytes,
    itemsize=_itemsize,
    at=_IndexUpdateHelper,
)
def _set_shaped_array_attributes(shaped_array):
  """Install jax.numpy operators, methods and properties on an aval class.

  This sets up operator, method and property forwarding for Tracer instances
  whose avals are of type ``shaped_array``, following Tracer's forwarding
  conventions:

  * operators are stored as static methods under a single-underscore name
    (e.g. ``"add"`` is stored as ``_add``);
  * methods are wrapped with ``core.aval_method``;
  * properties are wrapped with ``core.aval_property``.

  Finally ``_array_module`` is set to a static method wrapping
  ``__array_module__``.
  """
  for op_name, op_fn in _array_operators.items():
    setattr(shaped_array, "_" + op_name, staticmethod(op_fn))
  for meth_name, meth_fn in _array_methods.items():
    setattr(shaped_array, meth_name, core.aval_method(meth_fn))
  for prop_name, getter in _array_properties.items():
    setattr(shaped_array, prop_name, core.aval_property(getter))
  shaped_array._array_module = staticmethod(__array_module__)


_set_shaped_array_attributes(ShapedArray)
_set_shaped_array_attributes(DShapedArray)
def _set_device_array_base_attributes(device_array, include=None, exclude=None):
  """Attach jax.numpy operators, methods and properties to ``device_array``.

  No Tracer forwarding is involved: each function is set directly on the class.
  Operators become dunder methods (``"add"`` -> ``__add__``), methods are set
  as-is, and properties are wrapped in ``property``.

  Args:
    device_array: the class to modify in place via ``setattr``.
    include: optional collection of attribute names to set; when falsy
      (``None`` or empty) every attribute is eligible.
    exclude: optional collection of attribute names that are never set, even
      if they also appear in ``include``.
  """
  def should_set(attr_name):
    if exclude is not None and attr_name in exclude:
      return False
    return not include or attr_name in include

  candidates = [(f"__{op}__", fn) for op, fn in _array_operators.items()]
  candidates.extend(_array_methods.items())
  candidates.extend((name, property(getter))
                    for name, getter in _array_properties.items())
  for attr_name, target in candidates:
    if should_set(attr_name):
      setattr(device_array, attr_name, target)


_set_device_array_base_attributes(device_array.DeviceArray)
_set_device_array_base_attributes(ArrayImpl, exclude={'__getitem__'})
def _set_device_array_attributes(device_array):
  """Expose the module-level ``__array_module__`` hook on ``device_array``.

  The function is stored as a plain class attribute (no ``staticmethod``
  wrapper). Note that the parameter name shadows the ``device_array`` module
  inside this function only.
  """
  device_array.__array_module__ = __array_module__


for t in device_array.device_array_types:
  _set_device_array_attributes(t)
_set_device_array_attributes(pxla._ShardedDeviceArray)
_set_device_array_attributes(pmap_lib.ShardedDeviceArray)
_set_device_array_attributes(ArrayImpl)