# rocm_jax/jax/numpy/lax_numpy.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
NumPy operations are implemented in Python in terms of the primitive operations
in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
implemented in terms of :mod:`jax.lax` operations, we do not need to define
transformation rules such as gradient or batching rules. Instead,
transformations for NumPy primitives can be derived from the transformation
rules for the underlying :code:`lax` primitives.
"""
import builtins
import collections
import operator
import os
import types
from typing import Sequence, Set, Tuple, Union
from textwrap import dedent as _dedent
import warnings
import numpy as np
import opt_einsum
import jax
from jax import jit, custom_jvp
from .vectorize import vectorize
from ._util import _wraps
from .. import core
from .. import dtypes
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from ..config import flags, config
from ..interpreters.xla import DeviceArray
from ..interpreters.masking import Poly
from .. import lax
from ..lax.lax import _device_put_raw
from .. import ops
from ..util import (partial, unzip2, prod as _prod,
subvals, safe_zip, canonicalize_axis as _canonicalize_axis)
from ..tree_util import tree_leaves, tree_flatten
FLAGS = flags.FLAGS
flags.DEFINE_enum(
'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
enum_values=['allow', 'warn', 'raise'],
help=
'Control NumPy-style automatic rank promotion broadcasting '
'("allow", "warn", or "raise").')
newaxis = None
# Common docstring additions:
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, or any ``jax.lax.Precision`` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).
"""
# We replace some builtin names to follow Numpy's API, so we capture them here.
_abs = builtins.abs
_all = builtins.all
_any = builtins.any
_max = builtins.max
_min = builtins.min
_sum = builtins.sum
_divmod = builtins.divmod
# NumPy constants
pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
NINF = np.NINF
PZERO = np.PZERO
NZERO = np.NZERO
nan = np.nan
# And some numpy utility functions
set_printoptions = np.set_printoptions
# We want isinstance(x, np.ndarray) checks in user code to work with our
# array-like types, including DeviceArray and UnshapedArray (i.e. the abstract
# array base class). We can override the isinstance behavior directly, without
# having the complexity of multiple inheritance on those classes, by defining
# the ndarray class to have a metaclass with special __instancecheck__ behavior.
_arraylike_types = (np.ndarray, UnshapedArray, DeviceArray)
class _ArrayMeta(type(np.ndarray)): # type: ignore
"""Metaclass for overriding ndarray isinstance checks."""
def __instancecheck__(self, instance):
try:
return isinstance(instance.aval, _arraylike_types)
except AttributeError:
return isinstance(instance, _arraylike_types)
class ndarray(np.ndarray, metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int, ...]
size: int
def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
iscomplexobj = np.iscomplexobj
shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size
_dtype = dtypes.result_type
# At present JAX doesn't have a reason to distinguish between scalars and arrays
# in its object system. Further, we want JAX scalars to have the same type
# promotion behaviors as JAX arrays. Rather than introducing a new type of JAX
# scalar object with JAX promotion behaviors, instead we make the JAX scalar
# types return JAX arrays when instantiated.
class _ScalarMeta(type):
def __hash__(self):
return hash(self.dtype.type)
def __eq__(self, other):
return id(self) == id(other) or self.dtype.type == other
def __ne__(self, other):
return not (self == other)
def __call__(self, x):
return array(x, dtype=self.dtype)
def _make_scalar_type(np_scalar_type):
return _ScalarMeta(np_scalar_type.__name__, (object,),
{"dtype": np.dtype(np_scalar_type)})
bool_ = _make_scalar_type(np.bool_)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
float64 = double = _make_scalar_type(np.float64)
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)
int_ = int32 if dtypes.int_ == np.int32 else int64
float_ = float32 if dtypes.float_ == np.float32 else float64
complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
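# Illustrative example (not used elsewhere in this module): because of
# _ScalarMeta, calling a JAX scalar type builds a rank-0 JAX array of that
# dtype rather than a numpy scalar, while the type itself still compares equal
# to the corresponding numpy dtype. A minimal sketch in an illustrative helper.
def _example_scalar_types():
  x = float32(3)
  assert _dtype(x) == np.float32 and shape(x) == ()
  assert float32 == np.float32   # _ScalarMeta.__eq__ compares against the dtype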
number = np.number
inexact = np.inexact
complexfloating = np.complexfloating
floating = np.floating
integer = np.integer
signedinteger = np.signedinteger
unsignedinteger = np.unsignedinteger
flexible = np.flexible
character = np.character
object_ = np.object_
iinfo = dtypes.iinfo
dtype = np.dtype
can_cast = dtypes.can_cast
issubsctype = dtypes.issubsctype
promote_types = dtypes.promote_types
ComplexWarning = np.ComplexWarning
array_str = np.array_str
array_repr = np.array_repr
save = np.save
savez = np.savez
load = np.load
### utility functions
_DEFAULT_TYPEMAP = {
np.bool_: bool_,
np.int_: int_,
np.float_: float_,
np.complex_: complex_
}
def _np_array(obj, dtype=None, **kwargs):
"""Return a properly-typed numpy array.
`_np_array(obj, **kwds)` is equivalent to `np.array(obj, **kwds)`, with the
exception that when obj.dtype is not defined and dtype is not specified, it
uses Jax's default dtypes.
"""
arr = np.array(obj, dtype=dtype, **kwargs)
obj_dtype = getattr(obj, 'dtype', None)
arr_dtype = np.dtype(arr.dtype).type
if dtype is None and obj_dtype is None and arr_dtype in _DEFAULT_TYPEMAP:
arr = arr.astype(_DEFAULT_TYPEMAP[arr_dtype])
return arr
_np_asarray = partial(_np_array, copy=False)
def _promote_shapes(fun_name, *args):
"""Prepend implicit leading singleton dimensions for Numpy broadcasting."""
if len(args) < 2:
return args
else:
shapes = [shape(arg) for arg in args]
nonscalar_ranks = [len(shp) for shp in shapes if shp]
if not nonscalar_ranks or len(set(nonscalar_ranks)) == 1:
return args
else:
if FLAGS.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
result_rank = len(lax.broadcast_shapes(*shapes))
return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
for arg, shp in zip(args, shapes)]
def _rank_promotion_warning_or_error(fun_name, shapes):
if FLAGS.jax_numpy_rank_promotion == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
"disable this warning; for more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
elif FLAGS.jax_numpy_rank_promotion == "raise":
msg = ("Operands could not be broadcast together for {} on shapes {} "
"and with the config option jax_numpy_rank_promotion='raise'. "
"For more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
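# Illustrative example (not used elsewhere in this module): _promote_shapes only
# prepends leading singleton dimensions; the actual broadcasting is left to lax.
# A minimal sketch assuming the default 'allow' rank-promotion setting; the
# helper and the "example" name tag are made up for illustration.
def _example_promote_shapes():
  a, b = np.ones((3,)), np.ones((2, 3))
  a2, b2 = _promote_shapes("example", a, b)
  assert shape(a2) == (1, 3) and shape(b2) == (2, 3)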
def _promote_dtypes(*args):
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
return args
else:
to_dtype = result_type(*args)
return [lax.convert_element_type(x, to_dtype) for x in args]
def _promote_dtypes_inexact(*args):
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
to_dtype = _to_inexact_dtype(result_type(*args))
return [lax.convert_element_type(x, to_dtype) for x in args]
def _to_inexact_dtype(dtype):
"""Promotes a dtype into an inexact dtype, if it is not already one."""
return dtype if issubdtype(dtype, inexact) else promote_types(dtype, float_)
def _complex_elem_type(dtype):
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype
def _result_dtype(op, *args):
"""Compute result dtype of applying op to arguments with given dtypes."""
args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args]
return _dtype(op(*args))
def _arraylike(x): return isinstance(x, ndarray) or isscalar(x)
def _check_arraylike(fun_name, *args):
"""Check if all args fit JAX's definition of arraylike (ndarray or scalar)."""
if _any(not _arraylike(arg) for arg in args):
pos, arg = next((i, arg) for i, arg in enumerate(args)
if not _arraylike(arg))
msg = "{} requires ndarray or scalar arguments, got {} at position {}."
raise TypeError(msg.format(fun_name, type(arg), pos))
def _promote_args(fun_name, *args):
"""Convenience function to apply Numpy argument shape and dtype promotion."""
_check_arraylike(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes(*args))
def _promote_args_inexact(fun_name, *args):
"""Convenience function to apply Numpy argument shape and dtype promotion.
Promotes non-inexact types to an inexact type."""
_check_arraylike(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
def _constant_like(x, const):
return np.array(const, dtype=_dtype(x))
### implementations of numpy functions in terms of lax
@_wraps(np.fmin)
def fmin(x1, x2):
return where((x1 < x2) | isnan(x2), x1, x2)
@_wraps(np.fmax)
def fmax(x1, x2):
return where((x1 > x2) | isnan(x2), x1, x2)
@_wraps(np.finfo)
def finfo(dtype):
return dtypes.finfo(dtype)
@_wraps(np.issubdtype)
def issubdtype(arg1, arg2):
return dtypes.issubdtype(arg1, arg2)
@_wraps(np.isscalar)
def isscalar(element):
return dtypes.is_python_scalar(element) or np.isscalar(element)
iterable = np.iterable
@_wraps(np.result_type)
def result_type(*args):
return dtypes.result_type(*args)
def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
if promote_to_inexact:
def fn(x):
x = lax.convert_element_type(x, _to_inexact_dtype(_dtype(x)))
return lax_fn(x)
else:
fn = lambda x: lax_fn(x)
if lax_doc:
doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
else:
return _wraps(numpy_fn)(fn)
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
if promote_to_inexact:
fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
else:
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
if lax_doc:
doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
else:
return _wraps(numpy_fn)(fn)
def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
def fn(x1, x2):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
if lax_doc:
doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
else:
return _wraps(numpy_fn)(fn)
fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: x)
floor = _one_to_one_unop(np.floor, lax.floor, True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, True)
exp = _one_to_one_unop(np.exp, lax.exp, True)
log = _one_to_one_unop(np.log, lax.log, True)
expm1 = _one_to_one_unop(np.expm1, lax.expm1, True)
log1p = _one_to_one_unop(np.log1p, lax.log1p, True)
sin = _one_to_one_unop(np.sin, lax.sin, True)
cos = _one_to_one_unop(np.cos, lax.cos, True)
tan = _one_to_one_unop(np.tan, lax.tan, True)
arcsin = _one_to_one_unop(np.arcsin, lax.asin, True)
arccos = _one_to_one_unop(np.arccos, lax.acos, True)
arctan = _one_to_one_unop(np.arctan, lax.atan, True)
sinh = _one_to_one_unop(np.sinh, lax.sinh, True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, True)
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
arccosh = _one_to_one_unop(np.arccosh, lax.acosh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left)
equal = _one_to_one_binop(np.equal, lax.eq)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
subtract = _one_to_one_binop(np.subtract, lax.sub)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True)
minimum = _one_to_one_binop(np.minimum, lax.min)
maximum = _one_to_one_binop(np.maximum, lax.max)
float_power = _one_to_one_binop(np.float_power, lax.pow, True)
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True)
def _comparison_op(numpy_fn, lax_fn):
def fn(x1, x2):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
# Comparisons on complex types are defined as a lexicographic ordering on
# the (real, imag) pair.
if issubdtype(_dtype(x1), complexfloating):
rx = lax.real(x1)
ry = lax.real(x2)
return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)),
lax_fn(rx, ry))
return lax_fn(x1, x2)
return _wraps(numpy_fn)(fn)
greater_equal = _comparison_op(np.greater_equal, lax.ge)
greater = _comparison_op(np.greater, lax.gt)
less_equal = _comparison_op(np.less_equal, lax.le)
less = _comparison_op(np.less, lax.lt)
def _logical_op(np_op, bitwise_op):
@_wraps(np_op, update_doc=False)
def op(*args):
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
args = (x if issubdtype(_dtype(x), bool_) else lax.ne(x, zero(x))
for x in args)
return bitwise_op(*_promote_args(np_op.__name__, *args))
return op
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
@_wraps(np.right_shift)
def right_shift(x1, x2):
x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
lax_fn = lax.shift_right_logical if \
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)
@_wraps(np.absolute)
def absolute(x):
dt = _dtype(x)
return x if dt == bool_ or issubdtype(dt, unsignedinteger) else lax.abs(x)
abs = _wraps(np.abs)(absolute)
@_wraps(np.rint)
def rint(x):
dtype = _dtype(x)
if issubdtype(dtype, integer):
return lax.convert_element_type(x, float_)
if issubdtype(dtype, complexfloating):
return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
return _round_to_nearest_even(x)
@_wraps(np.sign)
def sign(x):
dtype = _dtype(x)
if issubdtype(dtype, complexfloating):
re = lax.real(x)
return lax.complex(
lax.sign(where(re != 0, re, lax.imag(x))), _constant_like(re, 0))
return lax.sign(x)
@_wraps(np.copysign)
def copysign(x1, x2):
if issubdtype(_dtype(x1), complexfloating) or issubdtype(_dtype(x2), complexfloating):
raise TypeError("copysign does not support complex-valued inputs")
x1, x2 = _promote_args_inexact("copysign", x1, x2)
return where(signbit(x2), -lax.abs(x1), lax.abs(x1))
@_wraps(np.true_divide)
def true_divide(x1, x2):
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
return lax.div(x1, x2)
divide = true_divide
@_wraps(np.floor_divide)
def floor_divide(x1, x2):
x1, x2 = _promote_args("floor_divide", x1, x2)
dtype = _dtype(x1)
if issubdtype(dtype, integer):
quotient = lax.div(x1, x2)
select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
# TODO(mattjj): investigate why subtracting a scalar was causing promotion
return where(select, quotient - np.array(1, _dtype(quotient)), quotient)
elif issubdtype(dtype, complexfloating):
x1r = lax.real(x1)
x1i = lax.imag(x1)
x2r = lax.real(x2)
x2i = lax.imag(x2)
which = lax.ge(lax.abs(x2r), lax.abs(x2i))
rat1 = where(which, lax._const(x2i, 1), lax.div(x2r, x2i))
rat2 = where(which, lax.div(x2i, x2r), lax._const(x2i, 1))
out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
return lax.convert_element_type(out, dtype)
else:
return _float_divmod(x1, x2)[0]
@_wraps(np.divmod)
def divmod(x1, x2):
x1, x2 = _promote_args("divmod", x1, x2)
if issubdtype(_dtype(x1), integer):
return floor_divide(x1, x2), remainder(x1, x2)
else:
return _float_divmod(x1, x2)
def _float_divmod(x1, x2):
# see float_divmod in floatobject.c of CPython
mod = lax.rem(x1, x2)
div = lax.div(lax.sub(x1, mod), x2)
ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod))
mod = lax.select(ind, mod + x2, mod)
div = lax.select(ind, div - _constant_like(div, 1), div)
return lax.round(div), mod
@_wraps(np.power)
def power(x1, x2):
# Special case for small positive integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
# pattern ``x ** 2`` or similar.
if isinstance(x2, int):
return lax.integer_pow(x1, x2)
x1, x2 = _promote_args(np.power, x1, x2)
dtype = _dtype(x1)
if not issubdtype(dtype, integer):
return lax.pow(x1, x2)
# Integer power => use binary exponentiation.
# TODO(phawkins): add integer pow support to XLA.
bits = 6 # Anything more would overflow for any x1 > 1
acc = ones(shape(x1), dtype=dtype)
for _ in range(bits):
acc = where(lax.bitwise_and(x2, _constant_like(x2, 1)),
lax.mul(acc, x1), acc)
x1 = lax.mul(x1, x1)
x2 = lax.shift_right_logical(x2, _constant_like(x2, 1))
return acc
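# Illustrative example (not used elsewhere in this module): the loop above is
# square-and-multiply over the bits of the exponent. The same idea in plain
# Python for a non-negative integer exponent; the helper is a sketch only.
def _example_square_and_multiply(base, exponent):
  acc = 1
  while exponent:
    if exponent & 1:          # low bit set: multiply the accumulator in
      acc *= base
    base *= base              # square for the next bit
    exponent >>= 1
  return acc                  # e.g. _example_square_and_multiply(3, 5) == 243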
@custom_jvp
@_wraps(np.logaddexp)
def logaddexp(x1, x2):
x1, x2 = _promote_shapes("logaddexp", *_promote_dtypes_inexact(x1, x2))
amax = lax.max(x1, x2)
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta)))))
@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2)
primal_out = logaddexp(x1, x2)
tangent_out = (t1 * exp(_replace_inf(x1) - _replace_inf(primal_out)) +
t2 * exp(_replace_inf(x2) - _replace_inf(primal_out)))
return primal_out, tangent_out
def _replace_inf(x):
return lax.select(isposinf(x), zeros_like(x), x)
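# Illustrative example (not used elsewhere in this module): logaddexp rewrites
# log(exp(x1) + exp(x2)) as max(x1, x2) + log1p(exp(-|x1 - x2|)) so that large
# inputs do not overflow. A minimal numerical check in an illustrative helper.
def _example_logaddexp_stability():
  small = logaddexp(1.0, 2.0)
  assert np.allclose(float(small), np.log(np.exp(1.0) + np.exp(2.0)))
  big = logaddexp(1000.0, 1000.0)   # the naive formula would overflow to inf
  assert np.isfinite(float(big))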
@custom_jvp
@_wraps(np.logaddexp2)
def logaddexp2(x1, x2):
x1, x2 = _promote_shapes("logaddexp2", *_promote_dtypes_inexact(x1, x2))
amax = lax.max(x1, x2)
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.div(lax.log1p(exp2(-lax.abs(delta))),
_constant_like(x1, np.log(2)))))
@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2)
primal_out = logaddexp2(x1, x2)
tangent_out = (t1 * 2 ** (_replace_inf(x1) - _replace_inf(primal_out)) +
t2 * 2 ** (_replace_inf(x2) - _replace_inf(primal_out)))
return primal_out, tangent_out
@_wraps(np.log2)
def log2(x):
x, = _promote_dtypes_inexact(x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@_wraps(np.log10)
def log10(x):
x, = _promote_dtypes_inexact(x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
@_wraps(np.exp2)
def exp2(x):
x, = _promote_dtypes_inexact(x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
@_wraps(np.signbit)
def signbit(x):
x, = _promote_shapes("signbit", x)
dtype = _dtype(x)
if issubdtype(dtype, integer):
return lax.lt(x, _constant_like(x, 0))
elif issubdtype(dtype, bool_):
return full_like(x, False, dtype=bool_)
elif not issubdtype(dtype, floating):
raise ValueError(
"jax.numpy.signbit is not well defined for %s" % dtype)
# TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
# F32.
if dtype == bfloat16:
dtype = float32
x = lax.convert_element_type(x, float32)
info = finfo(dtype)
if info.bits == 16:
int_type = np.int16
elif info.bits == 32:
int_type = np.int32
elif info.bits == 64:
int_type = np.int64
else:
raise NotImplementedError(
"jax.numpy.signbit only supports 16, 32, and 64-bit types.")
x = lax.bitcast_convert_type(x, int_type)
return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_)
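# Illustrative example (not used elsewhere in this module): the sign of an IEEE
# float is its most significant bit, so after bitcasting to a same-width
# integer, shifting right by (nexp + nmant) bits isolates it. The same trick in
# plain numpy for float32 (1 sign + 8 exponent + 23 mantissa bits).
def _example_signbit_bitcast():
  x = np.array(-2.5, np.float32)
  bits = x.view(np.int32)
  assert bool(bits >> (8 + 23)) == bool(np.signbit(x))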
@_wraps(np.trapz)
def trapz(y, x=None, dx=1.0, axis=-1):
y = moveaxis(y, axis, -1)
if x is not None:
if ndim(x) == 1:
dx = diff(x)
else:
dx = moveaxis(diff(x, axis=axis), axis, -1)
return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1)
@_wraps(np.trunc)
def trunc(x):
return where(lax.lt(x, lax._const(x, 0)), ceil(x), floor(x))
def _conv(x, y, mode, op, precision):
if issubdtype(x.dtype, complexfloating) or issubdtype(y.dtype, complexfloating):
raise NotImplementedError(f"{op}() does not support complex inputs")
if ndim(x) != 1 or ndim(y) != 1:
raise ValueError(f"{op}() only support 1-dimensional inputs.")
x, y = _promote_dtypes_inexact(x, y)
if len(x) == 0 or len(y) == 0:
raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")
out_order = slice(None)
if len(x) < len(y):
x, y = y, x
if op == "correlate":
out_order = slice(None, None, -1)
if op == 'convolve':
y = y[::-1]
if mode == 'valid':
padding = [(0, 0)]
elif mode == 'same':
padding = [(y.shape[0] // 2, y.shape[0] - y.shape[0] // 2 - 1)]
elif mode == 'full':
padding = [(y.shape[0] - 1, y.shape[0] - 1)]
else:
raise ValueError("mode must be one of ['full', 'same', 'valid']")
result = lax.conv_general_dilated(x[None, None, :], y[None, None, :], (1,),
padding, precision=precision)
return result[0, 0, out_order]
@_wraps(np.convolve, lax_description=_PRECISION_DOC)
def convolve(a, v, mode='full', *, precision=None):
return _conv(a, v, mode, 'convolve', precision)
@_wraps(np.correlate, lax_description=_PRECISION_DOC)
def correlate(a, v, mode='valid', *, precision=None):
return _conv(a, v, mode, 'correlate', precision)
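# Illustrative example (not used elsewhere in this module): both wrappers reduce
# a 1D convolution/correlation to lax.conv_general_dilated on operands reshaped
# to (batch=1, channel=1, length), with the padding chosen by `mode`. A minimal
# check against numpy's reference result, in an illustrative helper.
def _example_convolve():
  a = np.array([1., 2., 3.])
  v = np.array([0., 1., 0.5])
  assert np.allclose(convolve(a, v, mode='full'), np.convolve(a, v, mode='full'))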
def _normalize_float(x):
info = finfo(_dtype(x))
cond = lax.abs(x) < info.tiny
x1 = where(cond, x * (1 << info.nmant), x)
x2 = where(cond,
full_like(x, -info.nmant, dtype=np.int32),
zeros_like(x, dtype=np.int32))
return lax.convert_element_type(x1, _dtype(x)), x2
_INT_DTYPES = {
16: np.int16,
32: np.int32,
64: np.int64,
}
@_wraps(np.ldexp)
@jit
def ldexp(x1, x2):
dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
x1, x2 = _promote_shapes("ldexp", x1, x2)
x1 = lax.convert_element_type(x1, dtype)
info = finfo(dtype)
mask = (1 << info.nexp) - 1
bias = ((1 << info.nexp) - 1) >> 1
int_type = _INT_DTYPES[info.bits]
x, e = _normalize_float(x1)
x2 += lax.convert_element_type(e, np.int32)
x = lax.bitcast_convert_type(x, int_type)
x2 += ((x >> info.nmant) & mask) - bias
# find underflow/overflow before denormalization
underflow_cond = x2 < -(bias + info.nmant)
overflow_cond = x2 > bias
m = ones_like(x, dtype=dtype)
# denormals
cond = x2 < -bias + 1
x2 = where(cond, x2 + info.nmant, x2)
m = where(cond, m / (1 << info.nmant), m)
x2 = lax.convert_element_type(x2, np.int32)
x &= ~(mask << info.nmant)
x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)
x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)
# underflow
x = where(underflow_cond, zeros_like(x, dtype=dtype), x)
# overflow
x = where(overflow_cond, lax.sign(x1) * full_like(x, np.inf), x)
# ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
return where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
@_wraps(np.frexp)
@jit
def frexp(x):
x = asarray(x)
if issubdtype(x.dtype, complexfloating):
raise TypeError("frexp does not support complex-valued inputs")
elif not issubdtype(x.dtype, floating):
x = lax.convert_element_type(x, float_)
dtype = _dtype(x)
info = finfo(dtype)
mask = (1 << info.nexp) - 1
bias = ((1 << info.nexp) - 1) >> 1
int_type = _INT_DTYPES[info.bits]
x1, x2 = _normalize_float(x)
x1 = lax.bitcast_convert_type(x1, int_type)
x2 += ((x1 >> info.nmant) & mask) - bias + 1
x1 &= ~(mask << info.nmant)
x1 |= (bias - 1) << info.nmant
x1 = lax.bitcast_convert_type(x1, dtype)
cond = isinf(x) | isnan(x) | (x == 0)
x2 = where(cond, zeros_like(x2), x2)
return where(cond, x, x1), lax.convert_element_type(x2, int32)
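# Illustrative example (not used elsewhere in this module): frexp and ldexp are
# inverses: frexp splits x into a mantissa in [0.5, 1) and an integer exponent,
# and ldexp(mantissa, exponent) reassembles x. A sketch in an illustrative helper.
def _example_frexp_ldexp_roundtrip():
  m, e = frexp(np.float32(6.0))       # 6 = 0.75 * 2**3
  assert np.allclose(float(m), 0.75) and int(e) == 3
  assert np.allclose(float(ldexp(m, e)), 6.0)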
@_wraps(np.remainder)
def remainder(x1, x2):
x1, x2 = _promote_args("remainder", x1, x2)
zero = _constant_like(x1, 0)
trunc_mod = lax.rem(x1, x2)
trunc_mod_not_zero = lax.ne(trunc_mod, zero)
do_plus = lax.bitwise_and(
lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
mod = _wraps(np.mod)(remainder)
@_wraps(np.fmod)
def fmod(x1, x2):
if issubdtype(_dtype(x1, x2), integer):
x2 = where(x2 == 0, 1, x2)
return lax.rem(*_promote_args(np.fmod, x1, x2))
@_wraps(np.cbrt)
def cbrt(x):
x, = _promote_dtypes_inexact(x)
return lax.sign(x) * power(lax.abs(x), _constant_like(x, 1. / 3.))
@_wraps(np.square)
def square(x): return lax.integer_pow(x, 2)
@_wraps(np.deg2rad)
def deg2rad(x):
x, = _promote_dtypes_inexact(x)
return lax.mul(x, lax._const(x, pi / 180))
@_wraps(np.rad2deg)
def rad2deg(x):
x, = _promote_dtypes_inexact(x)
return lax.mul(x, lax._const(x, 180 / pi))
degrees = rad2deg
radians = deg2rad
@_wraps(np.histogram_bin_edges)
def histogram_bin_edges(a, bins=10, range=None, weights=None):
if isinstance(bins, str):
raise NotImplementedError("string values for `bins` not implemented.")
a = ravel(a)
b = array(bins)
if b.ndim == 1:
return b
if range is None:
range = (a.min(), a.max())
assert len(range) == 2
range = asarray(range)
range = (where(ptp(range) == 0, range[0] - 0.5, range[0]),
where(ptp(range) == 0, range[1] + 0.5, range[1]))
dtype = _dtype(a)
if issubdtype(dtype, integer):
dtype = promote_types(dtype, float32)
return linspace(range[0], range[1], bins + 1, dtype=dtype)
@_wraps(np.histogram)
def histogram(a, bins=10, range=None, weights=None, density=None):
if weights is not None and a.shape != weights.shape:
raise ValueError("weights should have the same shape as a.")
a = ravel(a)
if weights is not None:
weights = ravel(weights)
else:
weights = ones_like(a)
bin_edges = histogram_bin_edges(a, bins, range, weights)
bin_idx = searchsorted(bin_edges, a, side='right')
bin_idx = where(a == bin_edges[-1], len(bin_edges) - 1, bin_idx)
counts = bincount(bin_idx, weights, length=len(bin_edges))[1:]
if density:
bin_widths = diff(bin_edges)
counts = counts / bin_widths / counts.sum()
return counts, bin_edges
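# Illustrative example (not used elsewhere in this module): with an integer
# `bins`, the edges come from histogram_bin_edges (linspace over the data
# range) and the counts from searchsorted + bincount rather than a Python loop.
# A minimal sketch in an illustrative helper.
def _example_histogram():
  data = np.array([0.1, 0.4, 0.5, 0.9])
  counts, edges = histogram(data, bins=2)
  assert np.array_equal(counts, np.array([2, 2]))       # two values per bin
  assert np.allclose(edges, np.array([0.1, 0.5, 0.9]))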
@_wraps(np.heaviside)
def heaviside(x1, x2):
x1, x2 = _promote_dtypes_inexact(x1, x2)
zero = lax._const(x1, 0)
return where(lax.lt(x1, zero), zero,
where(lax.gt(x1, zero), lax._const(x1, 1), x2))
@_wraps(np.hypot)
def hypot(x1, x2):
x1, x2 = _promote_dtypes_inexact(x1, x2)
return lax.sqrt(x1*x1 + x2*x2)
@_wraps(np.reciprocal)
def reciprocal(x):
x, = _promote_dtypes_inexact(x)
return lax.integer_pow(x, -1)
@_wraps(np.sinc, update_doc=False)
def sinc(x):
x, = _promote_dtypes_inexact(x)
eq_zero = lax.eq(x, lax._const(x, 0))
safe_x = where(eq_zero, lax._const(x, 0), x)
pi_x = lax.mul(lax._const(x, pi), safe_x)
return where(eq_zero,
lax._const(x, 1), lax.div(lax.sin(pi_x), pi_x))
@_wraps(np.transpose)
def transpose(a, axes=None):
axes = np.arange(ndim(a))[::-1] if axes is None else axes
return lax.transpose(a, axes)
@_wraps(np.rot90)
def rot90(m, k=1, axes=(0, 1)):
ax1, ax2 = axes
ax1 = _canonicalize_axis(ax1, m.ndim)
ax2 = _canonicalize_axis(ax2, m.ndim)
if ax1 == ax2:
raise ValueError("Axes must be different") # same as numpy error
k = k % 4
if k == 0:
return m
elif k == 2:
return flip(flip(m, ax1), ax2)
else:
perm = list(range(m.ndim))
perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
if k == 1:
return transpose(flip(m, ax2), perm)
else:
return flip(transpose(m, perm), ax2)
@_wraps(np.flip)
def flip(m, axis=None):
if axis is None:
return lax.rev(m, list(range(len(m.shape))))
return lax.rev(m, [_canonicalize_axis(axis, len(m.shape))])
@_wraps(np.fliplr)
def fliplr(m):
return flip(m, 1)
@_wraps(np.flipud)
def flipud(m):
return flip(m, 0)
@_wraps(np.conjugate)
def conjugate(x):
return lax.conj(x) if iscomplexobj(x) else x
conj = conjugate
@_wraps(np.imag)
def imag(val):
return lax.imag(val) if iscomplexobj(val) else zeros_like(val)
@_wraps(np.real)
def real(val):
return lax.real(val) if iscomplexobj(val) else val
@_wraps(np.iscomplex)
def iscomplex(x):
i = imag(x)
return lax.ne(i, lax._const(i, 0))
@_wraps(np.isreal)
def isreal(x):
i = imag(x)
return lax.eq(i, lax._const(i, 0))
@_wraps(np.angle)
def angle(z):
re = real(z)
im = imag(z)
dtype = _dtype(re)
if not issubdtype(dtype, inexact) or (
issubdtype(_dtype(z), floating) and ndim(z) == 0):
dtype = dtypes.canonicalize_dtype(float_)
re = lax.convert_element_type(re, dtype)
im = lax.convert_element_type(im, dtype)
return lax.atan2(im, re)
@_wraps(np.diff)
def diff(a, n=1, axis=-1):
_check_arraylike("diff", a)
if n == 0:
return a
if n < 0:
raise ValueError(f"order must be non-negative but got {n}")
if ndim(a) == 0:
raise ValueError(f"diff requires input that is at least one dimensional; got {a}")
nd = a.ndim
slice1 = [slice(None)] * nd
slice2 = [slice(None)] * nd
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
op = not_equal if a.dtype == np.bool_ else subtract
for _ in range(n):
a = op(a[slice1], a[slice2])
return a
_EDIFF1D_DOC = """\
Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not
issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary``
loses precision.
"""
@_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC)
def ediff1d(ary, to_end=None, to_begin=None):
ary = ravel(asarray(ary))
result = lax.sub(ary[1:], ary[:-1])
if to_begin is not None:
result = concatenate((ravel(asarray(to_begin, dtype=ary.dtype)), result))
if to_end is not None:
result = concatenate((result, ravel(asarray(to_end, dtype=ary.dtype))))
return result
@partial(jit, static_argnums=2)
def _gradient(a, varargs, axis):
def gradient_along_axis(a, h, axis):
sliced = partial(lax.slice_in_dim, a, axis=axis)
a_grad = concatenate((
(sliced(1, 2) - sliced(0, 1)), # upper edge
(sliced(2, None) - sliced(None, -2)) * 0.5, # inner
(sliced(-1, None) - sliced(-2, -1)), # lower edge
), axis)
return a_grad / h
if axis is None:
axis = range(a.ndim)
else:
if isinstance(axis, int):
axis = (axis,)
if not isinstance(axis, tuple) and not isinstance(axis, list):
raise ValueError("Give `axis` either as int or iterable")
elif len(axis) == 0:
return []
axis = [_canonicalize_axis(i, a.ndim) for i in axis]
if _min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
raise ValueError("Shape of array too small to calculate "
"a numerical gradient, "
"at least 2 elements are required.")
len_axes = len(axis)
n = len(varargs)
if n == 0 or varargs is None:
# no spacing
dx = [1.0] * len_axes
elif n == 1:
# single value for all axes
dx = varargs * len_axes
elif n == len_axes:
dx = varargs
else:
TypeError("Invalid number of spacing arguments %d" % n)
if ndim(dx[0]) != 0:
raise NotImplementedError("Non-constant spacing not implemented")
# TODO: use jax.lax loop tools if possible
a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis, dx)]
if len(axis) == 1:
a_grad = a_grad[0]
return a_grad
@_wraps(np.gradient)
def gradient(f, *args, **kwargs):
axis = kwargs.pop("axis", None)
if not len(kwargs) == 0:
raise ValueError("Only `axis` keyword is implemented")
return _gradient(f, args, axis)
@_wraps(np.isrealobj)
def isrealobj(x):
return not iscomplexobj(x)
@_wraps(np.reshape)
def reshape(a, newshape, order="C"):
try:
return a.reshape(newshape, order=order) # forward to method for ndarrays
except AttributeError:
return _reshape(a, newshape, order=order)
def _compute_newshape(a, newshape):
"""Fixes a -1 value in newshape, if present."""
# other errors, like having more than one -1, are caught downstream
try: iter(newshape)
except: iterable = False
else: iterable = True
def check(size):
return size if type(size) is Poly else core.concrete_or_error(
int, size, "The error arose in jax.numpy.reshape.")
newshape = [check(size) for size in newshape] if iterable else check(newshape)
newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
if newsize < 0:
fix = a.size // -newsize
return [d if d != -1 else fix for d in newshape]
else:
return newshape
def _reshape(a, newshape, order="C"):
computed_newshape = _compute_newshape(a, newshape)
if order == "C":
return lax.reshape(a, computed_newshape, None)
elif order == "F":
dims = np.arange(ndim(a))[::-1]
return lax.reshape(a, computed_newshape[::-1], dims).T
elif order == "A":
raise NotImplementedError("np.reshape order=A is not implemented.")
else:
raise ValueError("Unexpected value for 'order' argument: {}.".format(order))
def _reshape_method(a, *newshape, **kwargs):
order = kwargs.pop("order", "C")
if len(kwargs) == 1:
invalid_kwarg, = kwargs
msg = "'{}' is an invalid keyword argument for this function"
raise TypeError(msg.format(invalid_kwarg)) # same as NumPy error
elif kwargs:
invalid_kwargs = "'{}'".format("'".join(kwargs))
msg = "{} are invalid keyword arguments for this function"
raise TypeError(msg.format(invalid_kwargs)) # different from NumPy error
if (len(newshape) == 1 and not isinstance(newshape[0], int) and
type(newshape[0]) is not Poly):
newshape = newshape[0]
return _reshape(a, newshape, order=order)
@_wraps(np.ravel)
def ravel(a, order="C"):
if order == "K":
raise NotImplementedError("Ravel not implemented for order='K'.")
return reshape(a, (size(a),), order)
@_wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims)
for index in multi_index:
_check_arraylike("ravel_multi_index", index)
if mode == 'raise':
core.concrete_or_error(array, index,
"The error occurred because ravel_multi_index was jit-compiled"
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
if not issubdtype(_dtype(index), integer):
raise TypeError("only int indices permitted")
if mode == "raise":
if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index, dims)):
raise ValueError("invalid entry in coordinates array")
elif mode == "clip":
multi_index = [clip(i, 0, d - 1) for i, d in zip(multi_index, dims)]
elif mode == "wrap":
multi_index = [i % d for i, d in zip(multi_index, dims)]
else:
raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'")
if order == "F":
strides = np.cumprod((1,) + dims[:-1])
elif order == "C":
strides = np.cumprod((1,) + dims[1:][::-1])[::-1]
else:
raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")
result = 0
for i, s in zip(multi_index, strides):
result = result + i * s
return result
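# Illustrative example (not used elsewhere in this module): in the default C
# order the stride for an axis is the product of all dimensions to its right,
# so for dims (3, 4) the flat index of (i, j) is i * 4 + j, matching
# numpy.ravel_multi_index. A sketch in an illustrative helper.
def _example_ravel_multi_index():
  flat = ravel_multi_index((np.array([1, 2]), np.array([0, 3])), dims=(3, 4))
  assert np.array_equal(flat, np.array([4, 11]))        # 1*4 + 0 and 2*4 + 3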
_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped.
"""
@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices, shape):
indices = asarray(indices)
sizes = pad(shape, (0, 1), constant_values=1)
cumulative_sizes = cumprod(sizes[::-1])[::-1]
total_size = cumulative_sizes[0]
# Clip so raveling and unraveling an oob index will not change the behavior
clipped_indices = clip(indices, -total_size, total_size - 1)
# Add enough trailing dims to avoid conflict with flat_index
cumulative_sizes = cumulative_sizes.reshape([-1] + [1] * indices.ndim)
idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:]
return tuple(idx)
@_wraps(np.squeeze)
def squeeze(a, axis: Union[int, Tuple[int, ...]] = None):
if axis is None:
a_shape = shape(a)
axis = tuple(i for i, d in enumerate(a_shape) if d == 1)
elif not isinstance(axis, tuple):
axis = (axis,)
return lax.squeeze(a, axis)
@_wraps(np.expand_dims)
def expand_dims(a, axis: Union[int, Tuple[int, ...]]):
if not isinstance(axis, tuple):
axis = (axis,)
return lax.expand_dims(a, axis)
@_wraps(np.swapaxes)
def swapaxes(a, axis1, axis2):
perm = np.arange(ndim(a))
perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
return lax.transpose(a, perm)
@_wraps(np.moveaxis)
def moveaxis(a, source, destination):
_check_arraylike("moveaxis", a)
try:
source = (operator.index(source),)
except TypeError:
pass
try:
destination = (operator.index(destination),)
except TypeError:
pass
source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
if len(source) != len(destination):
raise ValueError("Inconsistent number of elements: {} vs {}"
.format(len(source), len(destination)))
perm = [i for i in range(ndim(a)) if i not in source]
for dest, src in sorted(zip(destination, source)):
perm.insert(dest, src)
return lax.transpose(a, perm)
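# Illustrative example (not used elsewhere in this module): moveaxis builds a
# full permutation and delegates to lax.transpose; scalar source/destination
# values are accepted as well as matching-length tuples. A sketch in an
# illustrative helper.
def _example_moveaxis():
  x = np.zeros((2, 3, 4))
  assert shape(moveaxis(x, 0, -1)) == (3, 4, 2)
  assert shape(moveaxis(x, (0, 1), (2, 0))) == (3, 4, 2)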
@_wraps(np.isclose)
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
a, b = _promote_args("isclose", asarray(a), asarray(b))
dtype = _dtype(a)
if issubdtype(dtype, inexact):
if issubdtype(dtype, complexfloating):
dtype = _complex_elem_type(dtype)
rtol = lax.convert_element_type(rtol, dtype)
atol = lax.convert_element_type(atol, dtype)
out = lax.le(
lax.abs(lax.sub(a, b)),
lax.add(atol, lax.mul(rtol, lax.abs(b))))
# This corrects the comparisons for infinite and nan values
a_inf = isinf(a)
b_inf = isinf(b)
any_inf = logical_or(a_inf, b_inf)
both_inf = logical_and(a_inf, b_inf)
# Make all elements where either a or b is infinite to False
out = logical_and(out, logical_not(any_inf))
# Make all elements where both a and b are the same inf to True
same_value = lax.eq(a, b)
same_inf = logical_and(both_inf, same_value)
out = logical_or(out, same_inf)
# Make all elements where either a or b is NaN to False
a_nan = isnan(a)
b_nan = isnan(b)
any_nan = logical_or(a_nan, b_nan)
out = logical_and(out, logical_not(any_nan))
if equal_nan:
# Make all elements where both a and b are NaN to True
both_nan = logical_and(a_nan, b_nan)
out = logical_or(out, both_nan)
return _maybe_numpy_1_13_isclose_behavior(a, out)
else:
return lax.eq(a, b)
numpy_version = tuple(map(int, np.version.version.split('.')[:2]))
if numpy_version < (1, 14):
# see discussion at https://github.com/numpy/numpy/pull/9720
def _maybe_numpy_1_13_isclose_behavior(a, out):
if size(out) == 1 and issubdtype(_dtype(a), complexfloating):
return lax.reshape(out, (1,))
else:
return out
else:
def _maybe_numpy_1_13_isclose_behavior(a, out):
return out
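# Illustrative example (not used elsewhere in this module): for finite inputs
# the test is elementwise |a - b| <= atol + rtol * |b|, so like numpy.isclose it
# is not symmetric in a and b. A minimal sketch in an illustrative helper.
def _example_isclose_tolerances():
  assert bool(isclose(1.0, 1.0 + 5e-9))            # within the default atol
  assert not bool(isclose(1e-9, 2e-9, atol=0.0))   # rtol alone is relative to b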
@_wraps(np.interp)
def interp(x, xp, fp, left=None, right=None, period=None):
if shape(xp) != shape(fp) or ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x, xp, fp = map(asarray, _promote_dtypes_inexact(x, xp, fp))
if period is not None:
if period == 0:
raise ValueError(f"period must be a non-zero value; got {period}")
period = abs(period)
x = x % period
xp = xp % period
xp, fp = lax.sort_key_val(xp, fp)
xp = concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = concatenate([fp[-1:], fp, fp[:1]])
i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
df = fp[i] - fp[i - 1]
dx = xp[i] - xp[i - 1]
delta = x - xp[i - 1]
f = where((dx == 0), fp[i], fp[i - 1] + (delta / dx) * df)
if period is None:
f = where(x < xp[0], fp[0] if left is None else left, f)
f = where(x > xp[-1], fp[-1] if right is None else right, f)
return f
@_wraps(np.in1d, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def in1d(ar1, ar2, assume_unique=False, invert=False):
# TODO(vanderplas): use sorting-based approach for larger inputs.
ar1 = ravel(ar1)
ar2 = ravel(ar2)
if invert:
return (ar1[:, None] != ar2).all(-1)
else:
return (ar1[:, None] == ar2).any(-1)
@partial(jit, static_argnums=2)
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
"""
Helper function for intersect1d which is jit-able
"""
ar = concatenate((ar1, ar2))
if return_indices:
iota = lax.broadcasted_iota(np.int64, shape(ar), dimension=0)
aux, indices = lax.sort_key_val(ar, iota)
else:
aux = sort(ar)
mask = aux[1:] == aux[:-1]
if return_indices:
return aux, mask, indices
else:
return aux, mask
@_wraps(np.intersect1d)
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
if not assume_unique:
if return_indices:
ar1, ind1 = unique(ar1, return_index=True)
ar2, ind2 = unique(ar2, return_index=True)
else:
ar1 = unique(ar1)
ar2 = unique(ar2)
else:
ar1 = ravel(ar1)
ar2 = ravel(ar2)
if return_indices:
aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices)
else:
aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)
int1d = aux[:-1][mask]
if return_indices:
ar1_indices = aux_sort_indices[:-1][mask]
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
if not assume_unique:
ar1_indices = ind1[ar1_indices]
ar2_indices = ind2[ar2_indices]
return int1d, ar1_indices, ar2_indices
else:
return int1d
@_wraps(np.isin, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def isin(element, test_elements, assume_unique=False, invert=False):
result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert)
return result.reshape(shape(element))
# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@jit
def _where(condition, x=None, y=None):
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments should "
"be provided to jax.numpy.where, got {} and {}."
.format(x, y))
if not issubdtype(_dtype(condition), bool_):
condition = lax.ne(condition, zeros_like(condition))
x, y = _promote_dtypes(x, y)
condition, x, y = broadcast_arrays(condition, x, y)
return lax.select(condition, x, y) if np.size(x) else x
_WHERE_DOC = """\
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
three-argument form does not have a data-dependent shape and can be JIT-compiled
successfully.
"""
@_wraps(np.where, update_doc=False, lax_description=_WHERE_DOC)
def where(condition, x=None, y=None):
if x is None and y is None:
return nonzero(asarray(condition))
else:
return _where(condition, x, y)
@_wraps(np.select)
def select(condlist, choicelist, default=0):
if len(condlist) != len(choicelist):
msg = "condlist must have length equal to choicelist ({} vs {})"
raise ValueError(msg.format(len(condlist), len(choicelist)))
if len(condlist) == 0:
raise ValueError("condlist must be non-empty")
choices = _promote_dtypes(default, *choicelist)
choicelist = choices[1:]
output = choices[0]
for cond, choice in zip(condlist[::-1], choicelist[::-1]):
output = where(cond, choice, output)
return output
@_wraps(np.bincount, lax_description="""\
Jax adds the optional `length` parameter which specifies the output length, and
defaults to ``x.max() + 1``. It must be specified for bincount to be compilable.
Values larger than the specified length will be discarded.
Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` treats negative values as zero.
""")
def bincount(x, weights=None, minlength=0, *, length=None):
if not issubdtype(_dtype(x), integer):
msg = f"x argument to bincount must have an integer type; got {x.dtype}"
raise TypeError(msg)
if length is None:
length = max(x) + 1
length = _max(length, minlength)
if ndim(x) != 1:
raise ValueError("only 1-dimensional input supported.")
if weights is None:
weights = array(1, dtype=int32)
else:
if shape(x) != shape(weights):
raise ValueError("shape of weights must match shape of x.")
return ops.index_add(zeros((length,), _dtype(weights)), ops.index[clip(x, 0)], weights)
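# Illustrative example (not used elsewhere in this module): passing `length`
# gives the output a static shape, which is what makes bincount usable under
# jit; per the note above, values outside the range are discarded and negative
# values count toward bin 0. A minimal sketch in an illustrative helper.
def _example_bincount_length():
  x = np.array([0, 1, 1, 3])
  assert np.array_equal(bincount(x, length=6), np.array([1, 2, 0, 1, 0, 0]))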
def broadcast_arrays(*args):
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [shape(arg) for arg in args]
if len(set(shapes)) == 1:
return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg)
for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [broadcast_to(arg, result_shape) for arg in args]
@_wraps(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_to(arr, shape):
arr = arr if isinstance(arr, ndarray) else array(arr)
shape = canonicalize_shape(shape) # check that shape is concrete
arr_shape = _shape(arr)
if arr_shape == shape:
return arr
else:
nlead = len(shape) - len(arr_shape)
compatible = np.equal(arr_shape, shape[nlead:]) | np.equal(arr_shape, 1)
if nlead < 0 or not np.all(compatible):
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
raise ValueError(msg.format(arr_shape, shape))
diff, = np.where(np.not_equal(shape[nlead:], arr_shape))
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims)
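# Illustrative example (not used elsewhere in this module): broadcast_to
# squeezes the size-1 dimensions that need to grow and then calls
# lax.broadcast_in_dim, mapping each kept input dimension to its position in
# the requested shape. A sketch in an illustrative helper.
def _example_broadcast_to():
  y = broadcast_to(np.arange(3).reshape(1, 3), (4, 3))
  assert shape(y) == (4, 3)
  assert np.array_equal(y[2], np.arange(3))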
def _split(op, ary, indices_or_sections, axis=0):
axis = core.concrete_or_error(int, axis, f"in jax.numpy.{op} argument `axis`")
size = ary.shape[axis]
if isinstance(indices_or_sections, (tuple, list) + _arraylike_types):
indices_or_sections = [core.concrete_or_error(int, i_s, f"in jax.numpy.{op} argument 1")
for i_s in indices_or_sections]
split_indices = np.concatenate([[0], indices_or_sections, [size]])
else:
indices_or_sections = core.concrete_or_error(int, indices_or_sections,
f"in jax.numpy.{op} argument 1")
part_size, r = _divmod(size, indices_or_sections)
if r == 0:
split_indices = np.arange(indices_or_sections + 1) * part_size
elif op == "array_split":
split_indices = np.concatenate([np.arange(r + 1) * (part_size + 1),
np.arange(indices_or_sections - r) * part_size
+ ((r + 1) * (part_size + 1) - 1)])
else:
raise ValueError("array split does not result in an equal division")
starts, ends = [0] * ndim(ary), shape(ary)
_subval = lambda x, i, v: subvals(x, [(i, v)])
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
for start, end in zip(split_indices[:-1], split_indices[1:])]
@_wraps(np.split)
def split(ary, indices_or_sections, axis=0):
return _split("split", ary, indices_or_sections, axis=axis)
def _split_on_axis(np_fun, axis):
@_wraps(np_fun, update_doc=False)
def f(ary, indices_or_sections):
return split(ary, indices_or_sections, axis=axis)
return f
vsplit = _split_on_axis(np.vsplit, axis=0)
hsplit = _split_on_axis(np.hsplit, axis=1)
dsplit = _split_on_axis(np.dsplit, axis=2)
@_wraps(np.array_split)
def array_split(ary, indices_or_sections, axis=0):
return _split("array_split", ary, indices_or_sections, axis=axis)
@_wraps(np.clip)
def clip(a, a_min=None, a_max=None):
if a_min is None and a_max is None:
raise ValueError("At most one of a_min and a_max may be None")
if a_min is not None:
a = maximum(a_min, a)
if a_max is not None:
a = minimum(a_max, a)
return a
def _round_to_nearest_even(x):
half = lax._const(x, 0.5)
one = lax._const(x, 1)
round_val = lax.floor(x)
fraction = x - round_val
nearest_even_int = lax.sub(
round_val, lax.mul(lax._const(x, 2), lax.floor(lax.mul(half, x))))
is_odd = lax.eq(nearest_even_int, one)
return lax.select(
lax.bitwise_or(lax.gt(fraction, half),
lax.bitwise_and(lax.eq(fraction, half), is_odd)),
lax.add(round_val, one), round_val)
@_wraps(np.round, update_doc=False)
def round(a, decimals=0):
dtype = _dtype(a)
if issubdtype(dtype, integer):
if decimals < 0:
raise NotImplementedError(
"integer np.round not implemented for decimals < 0")
return a # no-op on integer types
def _round_float(x):
if decimals == 0:
return _round_to_nearest_even(x)
# TODO(phawkins): the strategy of rescaling the value isn't necessarily a
# good one since we may be left with an incorrectly rounded value at the
# end due to precision problems. As a workaround for float16, convert to
# float32.
x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x
factor = _constant_like(x, 10 ** decimals)
out = lax.div(_round_to_nearest_even(lax.mul(x, factor)), factor)
return lax.convert_element_type(out, dtype) if dtype == np.float16 else out
if issubdtype(dtype, complexfloating):
return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a)))
else:
return _round_float(a)
around = round
@_wraps(np.fix)
def fix(x, out=None):
if out is not None:
raise ValueError("fix does not support the `out` argument.")
zero = lax._const(x, 0)
return where(lax.ge(x, zero), floor(x), ceil(x))
@_wraps(np.modf)
def modf(x, out=None):
if out is not None:
raise ValueError("modf does not support the `out` argument.")
whole = fix(x)
return x - whole, whole
@_wraps(np.isfinite)
def isfinite(x):
dtype = _dtype(x)
if issubdtype(dtype, floating):
return lax.is_finite(x)
elif issubdtype(dtype, complexfloating):
return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
else:
return full_like(x, True, dtype=bool_)
@_wraps(np.isinf)
def isinf(x):
dtype = _dtype(x)
if issubdtype(dtype, floating):
return lax.eq(lax.abs(x), _constant_like(x, inf))
elif issubdtype(dtype, complexfloating):
re = lax.real(x)
im = lax.imag(x)
return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, inf)),
lax.eq(lax.abs(im), _constant_like(im, inf)))
else:
return full_like(x, False, dtype=bool_)
def _isposneginf(infinity, x):
dtype = _dtype(x)
if issubdtype(dtype, floating):
return lax.eq(x, _constant_like(x, infinity))
elif issubdtype(dtype, complexfloating):
raise ValueError("isposinf/isneginf are not well defined for complex types")
else:
return full_like(x, False, dtype=bool_)
isposinf = _wraps(np.isposinf)(lambda x: _isposneginf(inf, x))
isneginf = _wraps(np.isneginf)(lambda x: _isposneginf(-inf, x))
@_wraps(np.isnan)
def isnan(x):
return lax.bitwise_and(lax.bitwise_not(isfinite(x)),
lax.bitwise_not(isinf(x)))
@_wraps(np.nan_to_num)
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
del copy
dtype = _dtype(x)
if issubdtype(dtype, complexfloating):
return lax.complex(
nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf),
nan_to_num(lax.imag(x), nan=nan, posinf=posinf, neginf=neginf))
info = finfo(dtypes.canonicalize_dtype(dtype))
posinf = info.max if posinf is None else posinf
neginf = info.min if neginf is None else neginf
x = where(isnan(x), _constant_like(x, nan), x)
x = where(isposinf(x), _constant_like(x, posinf), x)
x = where(isneginf(x), _constant_like(x, neginf), x)
return x
### Reducers
def _make_reduction(name, np_fun, op, init_val, preproc=None, bool_op=None,
upcast_f16_for_computation=False):
"""Creates reduction function given a binary operation and monoid identity."""
bool_op = bool_op or op
@_wraps(np_fun)
def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
if out is not None:
raise ValueError("reduction does not support the `out` argument.")
_check_arraylike(name, a)
a = a if isinstance(a, ndarray) else asarray(a)
a = preproc(a) if preproc else a
dims = _reduction_dims(a, axis)
result_dtype = dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a))))
if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
computation_dtype = promote_types(result_dtype, float32)
else:
computation_dtype = result_dtype
a = lax.convert_element_type(a, computation_dtype)
result = lax.reduce(a, _reduction_init_val(a, init_val),
op if computation_dtype != np.bool_ else bool_op, dims)
if keepdims:
result = expand_dims(result, dims)
return lax.convert_element_type(result, dtype or result_dtype)
return reduction
def _reduction_dims(a, axis):
if axis is None:
return tuple(range(ndim(a)))
elif isinstance(axis, (np.ndarray, tuple, list)):
if len(axis) != len(set(axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
return tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
elif isinstance(axis, int):
return (_canonicalize_axis(axis, ndim(a)),)
else:
raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))
def _reduction_init_val(a, init_val):
a_dtype = dtypes.canonicalize_dtype(_dtype(a))
if a_dtype == 'bool':
return np.array(init_val > 0, dtype=a_dtype)
try:
return np.array(init_val, dtype=a_dtype)
except OverflowError:
assert issubdtype(a_dtype, integer)
sign, info = np.sign(init_val), iinfo(a_dtype)
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
_cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_)
sum = _make_reduction("sum", np.sum, lax.add, 0, upcast_f16_for_computation=True,
bool_op=lax.bitwise_or)
product = prod = _make_reduction("prod", np.prod, lax.mul, 1, bool_op=lax.bitwise_and,
upcast_f16_for_computation=True)
amax = max = _make_reduction("max", np.max, lax.max, -np.inf)
amin = min = _make_reduction("min", np.min, lax.min, np.inf)
all = alltrue = _make_reduction("all", np.all, lax.bitwise_and, True, _cast_to_bool)
any = sometrue = _make_reduction("any", np.any, lax.bitwise_or, False, _cast_to_bool)
@_wraps(np.mean)
def mean(a, axis=None, dtype=None, out=None, keepdims=False):
if out is not None:
raise ValueError("mean does not support the `out` argument.")
if axis is None:
normalizer = size(a)
else:
normalizer = np.prod(np.take(shape(a), axis))
if dtype is None:
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
dtype = float_
else:
dtype = _dtype(a)
return lax.div(
sum(a, axis, dtype=dtype, keepdims=keepdims),
lax.convert_element_type(normalizer, dtype))
@_wraps(np.average)
def average(a, axis=None, weights=None, returned=False):
a = asarray(a)
if weights is None: # Treat all weights as 1
avg = mean(a, axis=axis)
if axis is None:
weights_sum = full((), size(a), dtype=avg.dtype)
else:
weights_sum = full_like(avg, a.shape[axis], dtype=avg.dtype)
else:
weights = asarray(weights)
if issubdtype(a.dtype, inexact):
out_dtype = result_type(a.dtype, weights.dtype)
else:
out_dtype = result_type(a.dtype, weights.dtype, float_)
out_dtype = dtypes.canonicalize_dtype(out_dtype)
a_shape = shape(a)
a_ndim = len(a_shape)
weights_shape = shape(weights)
axis = None if axis is None else _canonicalize_axis(axis, a_ndim)
if a_shape != weights_shape:
# Make sure the dimensions work out
if axis is None:
raise ValueError("Axis must be specified when shapes of a and "
"weights differ.")
if len(weights_shape) != 1:
raise ValueError("1D weights expected when shapes of a and "
"weights differ.")
if weights_shape[0] != a_shape[axis]:
raise ValueError("Length of weights not "
"compatible with specified axis.")
weights = broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
weights = moveaxis(weights, -1, axis)
weights_sum = sum(weights, axis=axis, dtype=out_dtype)
avg = sum(multiply(a, weights), axis=axis, dtype=out_dtype) / weights_sum
if returned:
if avg.shape != weights_sum.shape:
weights_sum = broadcast_to(weights_sum, avg.shape)
return avg, weights_sum
return avg
@_wraps(np.var)
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
if out is not None:
raise ValueError("var does not support the `out` argument.")
a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
a_mean = mean(a, axis, dtype=a_dtype, keepdims=True)
centered = a - a_mean
if issubdtype(centered.dtype, complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
if axis is None:
normalizer = size(a)
else:
normalizer = np.prod(np.take(shape(a), axis))
normalizer = normalizer - ddof
result = sum(centered, axis, keepdims=keepdims)
out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
return lax.convert_element_type(out, dtype)
def _var_promote_types(a_dtype, dtype):
if dtype:
if (not issubdtype(dtype, complexfloating) and
issubdtype(a_dtype, complexfloating)):
msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment "
"on https://github.com/google/jax/issues/2283 if this behavior is "
"important to you.")
raise ValueError(msg)
a_dtype = promote_types(a_dtype, dtype)
else:
if not issubdtype(a_dtype, inexact):
dtype = a_dtype = float_
else:
dtype = _complex_elem_type(a_dtype)
a_dtype = promote_types(a_dtype, float32)
return a_dtype, dtype
@_wraps(np.std)
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
if out is not None:
raise ValueError("std does not support the `out` argument.")
return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims))
@_wraps(np.ptp)
def ptp(a, axis=None, out=None, keepdims=False):
if out is not None:
raise ValueError("ptp does not support the `out` argument.")
x = amax(a, axis=axis, keepdims=keepdims)
y = amin(a, axis=axis, keepdims=keepdims)
return lax.sub(x, y)
@_wraps(np.allclose)
def allclose(a, b, rtol=1e-05, atol=1e-08):
return all(isclose(a, b, rtol, atol))
@_wraps(np.count_nonzero)
def count_nonzero(a, axis=None, keepdims=False):
return sum(lax.ne(a, _constant_like(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
_NONZERO_DOC = """\
At present, JAX does not support JIT-compilation of :py:func:`jax.numpy.nonzero`
because its output shape is data-dependent.
"""
@_wraps(np.nonzero, lax_description=_NONZERO_DOC)
def nonzero(a):
# Note: this function cannot be jitted because its output has a dynamic
# shape.
a = atleast_1d(a)
dims = shape(a)
ndims = len(dims)
ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)]
d = concatenate(ds, axis=-1)
indexes = d[a != 0]
return tuple(indexes[..., i] for i in range(ndims))
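# Illustrative sketch (not part of the library): the output length of nonzero
# depends on the array's values, which is why it cannot be jit-compiled.
#   nonzero(array([0, 3, 0, 5]))   # -> (array with values [1, 3],)
#   nonzero(array([0, 0, 0, 5]))   # -> (array with values [3],)
# Same input shape, different output shapes.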
@_wraps(np.flatnonzero)
def flatnonzero(a):
return nonzero(ravel(a))[0]
def _make_nan_reduction(np_reduction, jnp_reduction, init_val, nan_if_all_nan):
@_wraps(np_reduction)
def nan_reduction(a, axis=None, out=None, keepdims=False, **kwargs):
out = jnp_reduction(where(isnan(a), _reduction_init_val(a, init_val), a),
axis=axis, out=out, keepdims=keepdims, **kwargs)
if nan_if_all_nan:
return where(all(isnan(a), axis=axis, keepdims=keepdims),
_constant_like(a, nan), out)
else:
return out
return nan_reduction
nanmin = _make_nan_reduction(np.nanmin, min, inf, nan_if_all_nan=True)
nanmax = _make_nan_reduction(np.nanmax, max, -inf, nan_if_all_nan=True)
nansum = _make_nan_reduction(np.nansum, sum, 0, nan_if_all_nan=False)
nanprod = _make_nan_reduction(np.nanprod, prod, 1, nan_if_all_nan=False)
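# Illustrative sketch (values assumed) of the nan_if_all_nan flag above:
#   nanmin(array([nan, 2.0, nan]))   # -> 2.0  (NaNs replaced by the init value +inf)
#   nanmin(array([nan, nan]))        # -> nan  (all-NaN slice propagates NaN)
#   nansum(array([nan, nan]))        # -> 0.0  (all-NaN slice reduces to the init value)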
@_wraps(np.nanmean)
def nanmean(a, axis=None, dtype=None, out=None, keepdims=False):
if out is not None:
raise ValueError("nanmean does not support the `out` argument.")
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
return mean(a, axis, dtype, out, keepdims)
if dtype is None:
dtype = _dtype(a)
nan_mask = logical_not(isnan(a))
normalizer = sum(nan_mask, axis=axis, dtype=int32, keepdims=keepdims)
normalizer = lax.convert_element_type(normalizer, dtype)
td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims), normalizer)
return td
@_wraps(np.nanvar)
def nanvar(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
if out is not None:
raise ValueError("nanvar does not support the `out` argument.")
a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True)
centered = a - a_mean
if issubdtype(centered.dtype, complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims)
normalizer = normalizer - ddof
if config.omnistaging_enabled:
normalizer_mask = lax.le(normalizer, 0)
else:
zero = lax.full_like(normalizer, 0, shape=())
normalizer_mask = lax.le(normalizer, zero)
result = nansum(centered, axis, keepdims=keepdims)
result = where(normalizer_mask, nan, result)
divisor = where(normalizer_mask, 1, normalizer)
out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
return lax.convert_element_type(out, dtype)
@_wraps(np.nanstd)
def nanstd(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
if out is not None:
raise ValueError("nanstd does not support the `out` argument.")
return sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims))
def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0):
# We want to allow XLA to fuse the pad and reduce-window operators to
# avoid materializing the padded output.
# Consider removing `jit` once again if reduce-window is generalized to
# support arbitrary padding.
@partial(jit, static_argnums=(1, 2))
def _cumulative_reduction(a, axis, dtype):
if axis is None or isscalar(a):
a = ravel(a)
axis = 0
a_shape = list(shape(a))
num_dims = len(a_shape)
axis = _canonicalize_axis(axis, num_dims)
if fill_nan:
a = where(isnan(a), _constant_like(a, fill_value), a)
if not dtype and _dtype(a) == bool_:
dtype = int_
if dtype:
a = lax.convert_element_type(a, dtype)
return reduction(a, axis)
@_wraps(np_reduction)
def cumulative_reduction(a, axis=None, dtype=None):
# jit doesn't support kwargs as static_args.
return _cumulative_reduction(a, axis, dtype)
return cumulative_reduction
cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
cumproduct = cumprod
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)
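# Illustrative sketch (values assumed) of the fill_value used above:
#   nancumsum(array([1.0, nan, 2.0]))    # -> [1.0, 1.0, 3.0]  (NaN treated as 0)
#   nancumprod(array([2.0, nan, 3.0]))   # -> [2.0, 2.0, 6.0]  (NaN treated as 1)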
@_wraps(np.unwrap)
def unwrap(p, discont=pi, axis=-1):
dd = diff(p, axis=axis)
ddmod = mod(dd + pi, 2 * pi) - pi
ddmod = where((ddmod == -pi) & (dd > 0), pi, ddmod)
ph_correct = where(abs(dd) < discont, 0, ddmod - dd)
up = concatenate((
lax.slice_in_dim(p, 0, 1, axis=axis),
lax.slice_in_dim(p, 1, None, axis=axis) + cumsum(ph_correct, axis=axis)
), axis=axis)
return up
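# Worked example (illustrative, values rounded): for p = [0.0, 2*pi - 0.1],
# dd = 2*pi - 0.1 exceeds the default discontinuity pi, ddmod wraps it to -0.1,
# so ph_correct = -2*pi and unwrap(p) is approximately [0.0, -0.1].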
### Array-creation functions
def _check_no_padding(axis_padding, mode):
if (axis_padding[0] > 0 or axis_padding[1] > 0):
msg = "Cannot apply '{}' padding to empty axis"
raise ValueError(msg.format(mode))
def _pad_constant(array, pad_width, constant_values):
nd = ndim(array)
constant_values = broadcast_to(asarray(constant_values), (nd, 2))
constant_values = lax.convert_element_type(constant_values, array.dtype)
for i in range(nd):
widths = [(0, 0, 0)] * nd
widths[i] = (pad_width[i, 0], 0, 0)
array = lax.pad(array, constant_values[i, 0], widths)
widths[i] = (0, pad_width[i, 1], 0)
array = lax.pad(array, constant_values[i, 1], widths)
return array
def _pad_wrap(array, pad_width):
for i in range(ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "wrap")
continue
size = array.shape[i]
repeats, (left_remainder, right_remainder) = _divmod(pad_width[i], size)
total_repeats = repeats.sum() + 1
parts = []
if left_remainder:
parts += [lax.slice_in_dim(array, size - left_remainder, size, axis=i)]
parts += total_repeats * [array]
if right_remainder:
parts += [lax.slice_in_dim(array, 0, right_remainder, axis=i)]
array = lax.concatenate(parts, dimension=i)
return array
def _pad_symmetric_or_reflect(array, pad_width, mode):
assert mode in ("symmetric", "reflect")
for i in range(ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], mode)
continue
n = array.shape[i]
rarray = lax.rev(array, dimensions=(i,))
offset = 1 if (mode == "reflect" and n > 1) else 0
def build_padding(padding, forward):
xs = []
delta = n - offset
while padding > delta:
padding -= delta
p = array if forward else rarray
xs.append(lax.slice_in_dim(p, offset, n, axis=i))
forward = not forward
if padding > 0:
x = lax.slice_in_dim(array if forward else rarray, offset,
padding + offset, axis=i)
xs.append(x)
return xs
parts = reversed(build_padding(pad_width[i, 0], forward=True))
parts = [lax.rev(x, dimensions=(i,)) for x in parts]
parts += [array]
parts += build_padding(pad_width[i, 1], forward=False)
array = lax.concatenate(parts, dimension=i)
return array
def _pad_edge(array, pad_width):
nd = ndim(array)
for i in range(nd):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "edge")
continue
n = array.shape[i]
npad_before, npad_after = pad_width[i]
edge_before = lax.slice_in_dim(array, 0, 1, axis=i)
pad_before = repeat(edge_before, npad_before, axis=i)
edge_after = lax.slice_in_dim(array, n-1, n, axis=i)
pad_after = repeat(edge_after, npad_after, axis=i)
array = lax.concatenate([pad_before, array, pad_after], dimension=i)
return array
@partial(jit, static_argnums=(1, 2))
def _pad(array, pad_width, mode, constant_values):
array = asarray(array)
nd = ndim(array)
pad_width = np.broadcast_to(np.asarray(pad_width), (nd, 2))
if np.any(pad_width < 0):
raise ValueError("index can't contain negative values")
if mode == "constant":
return _pad_constant(array, pad_width, constant_values)
elif mode == "wrap":
return _pad_wrap(array, pad_width)
elif mode in ("symmetric", "reflect"):
return _pad_symmetric_or_reflect(array, pad_width, mode)
elif mode == "edge":
return _pad_edge(array, pad_width)
else:
msg = "Unimplemented padding mode '{}' for np.pad."
raise NotImplementedError(msg.format(mode))
@_wraps(np.pad)
def pad(array, pad_width, mode='constant', constant_values=0):
if isinstance(pad_width, list):
pad_width = tuple(pad_width)
return _pad(array, pad_width, mode, constant_values)
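# Illustrative sketch (values assumed) of the padding modes above, for a 1D
# input [1, 2, 3] padded by 2 on each side:
#   mode='constant'  -> [0, 0, 1, 2, 3, 0, 0]
#   mode='edge'      -> [1, 1, 1, 2, 3, 3, 3]
#   mode='wrap'      -> [2, 3, 1, 2, 3, 1, 2]
#   mode='reflect'   -> [3, 2, 1, 2, 3, 2, 1]
#   mode='symmetric' -> [2, 1, 1, 2, 3, 3, 2]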
@_wraps(np.stack)
def stack(arrays, axis=0):
if not len(arrays):
raise ValueError("Need at least one array to stack.")
shape0 = shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
for a in arrays:
if shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis)
@_wraps(np.tile)
def tile(A, reps):
if isinstance(reps, int):
reps = (reps,)
A_shape = (1,) * (len(reps) - ndim(A)) + shape(A)
reps = (1,) * (len(A_shape) - len(reps)) + tuple(reps)
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
[k for pair in zip(reps, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps)))
@_wraps(np.concatenate)
def concatenate(arrays, axis=0):
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0)
axis = _canonicalize_axis(axis, ndim(arrays[0]))
arrays = _promote_dtypes(*arrays)
# lax.concatenate can be slow to compile for wide concatenations, so form a
# tree of concatenations as a workaround especially for op-by-op mode.
# (https://github.com/google/jax/issues/653).
k = 16
if len(arrays) == 1:
return array(arrays[0])
else:
while len(arrays) > 1:
arrays = [lax.concatenate(arrays[i:i+k], axis)
for i in range(0, len(arrays), k)]
return arrays[0]
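# Illustrative sketch of the concatenation tree above (k = 16): concatenating
# 100 arrays is lowered as ceil(100/16) = 7 concatenations of at most 16
# operands each, followed by a final concatenation of those 7 results.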
@_wraps(np.vstack)
def vstack(tup):
return concatenate([atleast_2d(m) for m in tup], axis=0)
row_stack = vstack
@_wraps(np.hstack)
def hstack(tup):
arrs = [atleast_1d(m) for m in tup]
if arrs[0].ndim == 1:
return concatenate(arrs, 0)
return concatenate(arrs, 1)
@_wraps(np.dstack)
def dstack(tup):
return concatenate([atleast_3d(m) for m in tup], axis=2)
@_wraps(np.column_stack)
def column_stack(tup):
arrays = []
for v in tup:
arr = array(v)
if arr.ndim < 2:
arr = atleast_2d(arr).T
arrays.append(arr)
return concatenate(arrays, 1)
def _atleast_nd(x, n):
m = ndim(x)
return lax.broadcast(x, (1,) * (n - m)) if m < n else x
def _block(xs):
if isinstance(xs, tuple):
raise ValueError("jax.numpy.block does not allow tuples, got {}"
.format(xs))
elif isinstance(xs, list):
if len(xs) == 0:
raise ValueError("jax.numpy.block does not allow empty list arguments")
xs, depths = unzip2([_block(x) for x in xs])
if _any(d != depths[0] for d in depths[1:]):
raise ValueError("Mismatched list depths in jax.numpy.block")
rank = _max(depths[0], _max(ndim(x) for x in xs))
xs = [_atleast_nd(x, rank) for x in xs]
return concatenate(xs, axis=-depths[0]), depths[0] + 1
else:
return asarray(xs), 1
@_wraps(np.block)
@jit
def block(arrays):
out, _ = _block(arrays)
return out
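# Illustrative sketch (values assumed): block assembles nested lists of arrays,
# concatenating the innermost lists along the last axis and outer lists along
# earlier axes. For example,
#   block([[ones((2, 2)), zeros((2, 1))],
#          [zeros((1, 2)), ones((1, 1))]])
# produces a (3, 3) array with a 2x2 block of ones in the upper-left corner.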
@_wraps(np.atleast_1d, update_doc=False)
def atleast_1d(*arys):
if len(arys) == 1:
arr = array(arys[0])
return arr if ndim(arr) >= 1 else reshape(arr, -1)
else:
return [atleast_1d(arr) for arr in arys]
@_wraps(np.atleast_2d, update_doc=False)
def atleast_2d(*arys):
if len(arys) == 1:
arr = array(arys[0])
if ndim(arr) >= 2:
return arr
elif ndim(arr) == 1:
return expand_dims(arr, axis=0)
else:
return expand_dims(arr, axis=(0, 1))
else:
return [atleast_2d(arr) for arr in arys]
@_wraps(np.atleast_3d, update_doc=False)
def atleast_3d(*arys):
if len(arys) == 1:
arr = array(arys[0])
if ndim(arr) == 0:
arr = expand_dims(arr, axis=(0, 1, 2))
elif ndim(arr) == 1:
arr = expand_dims(arr, axis=(0, 2))
elif ndim(arr) == 2:
arr = expand_dims(arr, axis=2)
return arr
else:
return [atleast_3d(arr) for arr in arys]
@_wraps(np.array)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")
lax._check_user_dtype_supported(dtype, "array")
dtype = dtype and dtypes.canonicalize_dtype(dtype)
if _can_call_numpy_array(object):
object = _np_array(object, dtype=dtype, ndmin=ndmin)
assert type(object) not in dtypes.python_scalar_dtypes
if type(object) is np.ndarray:
out = _device_put_raw(object)
if dtype: assert _dtype(out) == dtype
elif isinstance(object, (DeviceArray, core.Tracer)):
if isinstance(object, DeviceArray) and copy:
# We perform a copy by bouncing back to the host
# TODO(phawkins): add a device runtime function to copy a buffer
out = _device_put_raw(_np_asarray(object))
else:
out = object
elif isinstance(object, (list, tuple)):
if object:
out = stack([array(elt, dtype=dtype) for elt in object])
else:
out = _device_put_raw(_np_array([], dtype=dtype))
else:
try:
view = memoryview(object)
except TypeError:
pass # `object` does not support the buffer interface.
else:
return array(_np_asarray(view), dtype, copy)
raise TypeError("Unexpected input type for array: {}".format(type(object)))
if dtype and _dtype(out) != dtype:
out = lax.convert_element_type(out, dtype)
if ndmin > ndim(out):
out = lax.broadcast(out, (1,) * (ndmin - ndim(out)))
return out
def _can_call_numpy_array(x):
return _all(not isinstance(l, (core.Tracer, DeviceArray))
for l in tree_leaves(x))
@_wraps(np.asarray)
def asarray(a, dtype=None, order=None):
lax._check_user_dtype_supported(dtype, "asarray")
return array(a, dtype=dtype, copy=False, order=order)
@_wraps(np.zeros_like)
def zeros_like(a, dtype=None):
lax._check_user_dtype_supported(dtype, "zeros_like")
return lax.full_like(a, 0, dtype)
@_wraps(np.ones_like)
def ones_like(a, dtype=None):
lax._check_user_dtype_supported(dtype, "ones_like")
return lax.full_like(a, 1, dtype)
@_wraps(np.full)
def full(shape, fill_value, dtype=None):
lax._check_user_dtype_supported(dtype, "full")
shape = (shape,) if ndim(shape) == 0 else shape
return lax.full(shape, fill_value, dtype)
@_wraps(np.full_like)
def full_like(a, fill_value, dtype=None):
lax._check_user_dtype_supported(dtype, "full_like")
return lax.full_like(a, fill_value, dtype)
@_wraps(np.zeros)
def zeros(shape, dtype=None):
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "zeros")
dtype = float_ if dtype is None else dtype
shape = (shape,) if ndim(shape) == 0 else shape
return lax.full(shape, 0, dtype)
@_wraps(np.ones)
def ones(shape, dtype=None):
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "ones")
dtype = float_ if dtype is None else dtype
shape = (shape,) if ndim(shape) == 0 else shape
return lax.full(shape, 1, dtype)
@_wraps(np.array_equal)
def array_equal(a1, a2, equal_nan=False):
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
if shape(a1) != shape(a2):
return False
eq = asarray(a1 == a2)
if equal_nan:
eq = logical_or(eq, logical_and(isnan(a1), isnan(a2)))
return all(eq)
@_wraps(np.array_equiv)
def array_equiv(a1, a2):
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
try:
eq = equal(a1, a2)
except ValueError:
# shapes are not broadcastable
return False
return all(eq)
# We can't create uninitialized arrays in XLA; use zeros for empty.
empty_like = zeros_like
empty = zeros
@_wraps(np.eye)
def eye(N, M=None, k=0, dtype=None):
lax._check_user_dtype_supported(dtype, "eye")
dtype = float_ if dtype is None else dtype
M = N if M is None else M
k = int(k)
if N < 0 or M < 0:
msg = "negative dimensions are not allowed, got {} and {}"
raise ValueError(msg.format(N, M))
if k is not None:
k_dtype = _dtype(k)
if not issubdtype(k_dtype, integer):
msg = "eye argument `k` must be of integer dtype, got {}"
raise TypeError(msg.format(k_dtype))
return lax._eye(dtype, (N, M), k)
@_wraps(np.identity)
def identity(n, dtype=None):
lax._check_user_dtype_supported(dtype, "identity")
return eye(n, dtype=dtype)
@_wraps(np.arange)
def arange(start, stop=None, step=None, dtype=None):
lax._check_user_dtype_supported(dtype, "arange")
require = partial(core.concrete_or_error, _np_asarray)
msg = "It arose in jax.numpy.arange argument `{}`.".format
if stop is None and step is None:
start = require(start, msg("stop"))
dtype = dtype or _dtype(start)
return lax.iota(dtype, np.ceil(start)) # avoids materializing
else:
start = require(start, msg("start"))
stop = None if stop is None else require(stop, msg("stop"))
step = None if step is None else require(step, msg("step"))
if dtype is None:
dtype = _dtype(start, *(x for x in [stop, step] if x is not None))
return array(np.arange(start, stop=stop, step=step, dtype=dtype))
def _wrap_numpy_nullary_function(f):
"""Adapts `f` to return a DeviceArray instead of an np.ndarray.
`f` cannot have any non-static array arguments.
"""
@_wraps(f, update_doc=False)
def wrapper(*args, **kwargs):
return asarray(f(*args, **kwargs))
return wrapper
@_wraps(np.linspace)
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
"""Implementation of linspace differentiable in start and stop args."""
lax._check_user_dtype_supported(dtype, "linspace")
if num < 0:
raise ValueError("Number of samples, %s, must be non-negative." % num)
dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
broadcast_start = broadcast_to(start, bounds_shape)
broadcast_stop = broadcast_to(stop, bounds_shape)
axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
bounds_shape.insert(axis, 1)
iota_shape = [1,] * len(bounds_shape)
iota_shape[axis] = num
div = (num - 1) if endpoint else num
if num > 1:
delta = lax.convert_element_type(stop - start, computation_dtype) / div
if issubdtype(dtype, integer):
# This is similar to how numpy computes linspace, but it
# can fail to recover the endpoints in float32 arithmetic.
out = (reshape(broadcast_start, bounds_shape) +
reshape(lax.iota(dtype, num), iota_shape) *
reshape(delta, bounds_shape))
else:
# This approach recovers the endpoints with float32 arithmetic,
# but can lead to rounding errors for integer outputs.
step = reshape(lax.iota(computation_dtype, num), iota_shape) / div
out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
reshape(broadcast_stop, bounds_shape) * step)
elif num == 1:
delta = nan if endpoint else stop - start
out = reshape(broadcast_start, bounds_shape)
else: # num == 0 degenerate case, match numpy behavior
empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
empty_shape.insert(axis, 0)
delta = nan
out = reshape(array([], dtype=dtype), empty_shape)
if retstep:
return lax.convert_element_type(out, dtype), delta
else:
return lax.convert_element_type(out, dtype)
@_wraps(np.logspace)
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
"""Implementation of logspace differentiable in start and stop args."""
dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
lin = linspace(start, stop, num,
endpoint=endpoint, retstep=False, dtype=None, axis=axis)
return lax.convert_element_type(power(base, lin), dtype)
@_wraps(np.geomspace)
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
"""Implementation of geomspace differentiable in start and stop args."""
dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
# follow the numpy geomspace convention for negative and complex endpoints
signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2
res = signflip * logspace(log10(signflip * start),
log10(signflip * stop), num,
endpoint=endpoint, base=10.0,
dtype=computation_dtype, axis=0)
if axis != 0:
res = moveaxis(res, 0, axis)
return lax.convert_element_type(res, dtype)
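# Illustrative sketch of the sign-flip convention above (values assumed): when
# both endpoints are negative, the sign is flipped, logs are taken, and the
# sign is flipped back:
#   geomspace(-1000.0, -1.0, 4)   # -> approximately [-1000., -100., -10., -1.]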
@_wraps(np.meshgrid)
def meshgrid(*args, **kwargs):
indexing = kwargs.get("indexing", "xy")
sparse = kwargs.get("sparse", False)
copy = kwargs.get("copy", True)
if not copy:
raise ValueError("jax.numpy.meshgrid only supports copy=True")
args = list(args)
if indexing == "xy":
if len(args) >= 2:
args[0], args[1] = args[1], args[0]
elif indexing != "ij":
raise ValueError("Valid values for indexing are 'xy' and 'ij', got {}"
.format(indexing))
shape = []
for i, a in enumerate(args):
args[i] = a = asarray(a)
if len(a.shape) != 1:
msg = "Arguments to jax.numpy.meshgrid must be 1D, got shape {}"
raise ValueError(msg.format(a.shape))
shape.append(1 if sparse else a.shape[0])
output = []
for i, a in enumerate(args):
a = asarray(a)
s = shape
if sparse:
s = list(s)
s[i] = a.shape[0]
output.append(lax.broadcast_in_dim(a, s, (i,)))
if indexing == "xy" and len(args) >= 2:
output[0], output[1] = output[1], output[0]
return output
@_wraps(np.i0)
def i0(x):
x = lax.abs(*_promote_args_inexact("i0", x))
return lax.mul(lax.exp(x), lax.bessel_i0e(x))
@_wraps(np.ix_)
def ix_(*args):
n = len(args)
output = []
for i, a in enumerate(args):
a = asarray(a)
if len(a.shape) != 1:
msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
raise ValueError(msg.format(a.shape))
if _dtype(a) == bool_:
raise NotImplementedError(
"Boolean arguments to jax.numpy.ix_ are not implemented")
shape = [1] * n
shape[i] = a.shape[0]
if a.size == 0:
# Numpy uses an integer index type for empty arrays.
output.append(lax.full(shape, np.zeros((), np.intp)))
else:
output.append(lax.broadcast_in_dim(a, shape, (i,)))
return tuple(output)
@_wraps(np.indices)
def indices(dimensions, dtype=int32, sparse=False):
dimensions = tuple(dimensions)
N = len(dimensions)
output = []
s = dimensions
for i, dim in enumerate(dimensions):
idx = lax.iota(dtype, dim)
if sparse:
s = (1,)*i + (dim,) + (1,)*(N - i - 1)
output.append(lax.broadcast_in_dim(idx, s, (i,)))
if sparse:
return tuple(output)
return stack(output, 0) if output else array([], dtype=dtype)
_TOTAL_REPEAT_LENGTH_DOC = """\
Jax adds the optional `total_repeat_length` parameter, which specifies the total
number of repeats and defaults to `sum(repeats)`. It must be specified for
`repeat` to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length`, the remaining values will be discarded. If `sum(repeats)`
is smaller than the specified target length, the final value will be repeated.
"""
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a, repeats, axis=None, *, total_repeat_length=None):
if axis is None:
a = ravel(a)
axis = 0
# If total_repeat_length is not given, repeat cannot be compiled; fall back to
# computing it from concrete repeats as sum(repeats).
if total_repeat_length is None:
repeats = core.concrete_or_error(np.array, repeats, "It arose in jax.numpy.repeat.")
repeats = np.ravel(repeats)
if ndim(a) != 0:
repeats = np.broadcast_to(repeats, [a.shape[axis]])
total_repeat_length = np.sum(repeats)
else:
repeats = ravel(repeats)
if ndim(a) != 0:
repeats = broadcast_to(repeats, [a.shape[axis]])
# Special case when a is a scalar.
if ndim(a) == 0:
if repeats.shape == (1,):
return full([total_repeat_length], a)
else:
raise ValueError('`repeat` with a scalar parameter `a` is only '
'implemented for scalar values of the parameter `repeats`.')
# Special case if total_repeat_length is zero.
if total_repeat_length == 0:
result_shape = list(a.shape)
result_shape[axis] = 0
return reshape(array([], dtype=a.dtype), result_shape)
# If repeats is on a zero sized axis, then return the array.
if a.shape[axis] == 0:
return a
# This implementation of repeat avoids having to instantiate a large
# intermediate tensor.
# Modify repeats from e.g. [1,2,0,5] -> [0,1,2,0] for exclusive repeat.
exclusive_repeats = roll(repeats, shift=1).at[0].set(0)
# Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3]
scatter_indices = cumsum(exclusive_repeats)
# Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
block_split_indicators = ops.index_add(
x=zeros([total_repeat_length], dtype=int32),
idx=scatter_indices,
y=1)
# Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3]
gather_indices = cumsum(block_split_indicators) - 1
return take(a, gather_indices, axis=axis)
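# Worked example of the scatter/gather trick above (illustrative): for a 1D
# input [a, b, c, d] with repeats = [1, 2, 0, 5] and total_repeat_length = 8:
#   exclusive_repeats      = [0, 1, 2, 0]
#   scatter_indices        = [0, 1, 3, 3]
#   block_split_indicators = [1, 1, 0, 2, 0, 0, 0, 0]
#   gather_indices         = [0, 1, 1, 3, 3, 3, 3, 3]
# so the result is [a, b, b, d, d, d, d, d]: a once, b twice, d five times.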
@_wraps(np.tri)
def tri(N, M=None, k=0, dtype=None):
lax._check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32
return lax._tri(dtype, (N, M), k)
@_wraps(np.tril)
def tril(m, k=0):
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
mask = tri(*m_shape[-2:], k=k, dtype=bool)
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
@_wraps(np.triu, update_doc=False)
def triu(m, k=0):
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
@_wraps(np.trace)
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
if out:
raise NotImplementedError("The 'out' argument to trace is not supported.")
lax._check_user_dtype_supported(dtype, "trace")
axis1 = _canonicalize_axis(axis1, ndim(a))
axis2 = _canonicalize_axis(axis2, ndim(a))
a_shape = shape(a)
if dtype is None:
dtype = _dtype(a)
if issubdtype(dtype, integer):
default_int = dtypes.canonicalize_dtype(np.int_)
if iinfo(dtype).bits < iinfo(default_int).bits:
dtype = default_int
# Move the axis1 and axis2 dimensions to the end.
perm = [i for i in range(len(a_shape)) if i != axis1 and i != axis2]
perm = perm + [axis1, axis2]
a = lax.transpose(a, perm)
# Mask out everything except the offset diagonal, then reduce.
a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
a, zeros_like(a))
return sum(a, axis=(-2, -1), dtype=dtype)
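# Illustrative sketch of the mask-and-sum approach above (values assumed): for
# a = [[1, 2], [3, 4]], trace(a) keeps only the k=0 diagonal (1 and 4) and sums
# it to 5, while trace(a, offset=1) keeps only the element 2.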
def _wrap_indices_function(f):
@_wraps(f, update_doc=False)
def wrapper(*args, **kwargs):
return tuple(asarray(x) for x in f(*args, **kwargs))
return wrapper
tril_indices = _wrap_indices_function(np.tril_indices)
triu_indices = _wrap_indices_function(np.triu_indices)
mask_indices = _wrap_indices_function(np.mask_indices)
@_wraps(np.triu_indices_from)
def triu_indices_from(arr, k=0):
return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])
@_wraps(np.tril_indices_from)
def tril_indices_from(arr, k=0):
return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])
@_wraps(np.diag_indices)
def diag_indices(n, ndim=2):
if n < 0:
raise ValueError("n argument to diag_indices must be nonnegative, got {}"
.format(n))
if ndim < 0:
raise ValueError("ndim argument to diag_indices must be nonnegative, got {}"
.format(ndim))
return (lax.iota(int_, n),) * ndim
@_wraps(np.diag_indices_from)
def diag_indices_from(arr):
if not arr.ndim >= 2:
raise ValueError("input array must be at least 2-d")
if len(set(arr.shape)) != 1:
raise ValueError("All dimensions of input must be of equal length")
return diag_indices(arr.shape[0], ndim=arr.ndim)
@_wraps(np.diagonal)
def diagonal(a, offset=0, axis1=0, axis2=1):
a_shape = shape(a)
a_ndims = len(a_shape)
# Move the two dimensions to the end.
axis1 = _canonicalize_axis(axis1, a_ndims)
axis2 = _canonicalize_axis(axis2, a_ndims)
perm = [i for i in range(a_ndims) if i != axis1 and i != axis2]
perm = perm + [axis1, axis2]
a = lax.transpose(a, perm)
# Mask out everything except the offset diagonal, then reduce over one of the axes.
a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
a, zeros_like(a))
reduce_axis = -2 if offset < 0 else -1
d = sum(a, axis=reduce_axis, dtype=_dtype(a))
# Slice out the correct diagonal size.
diag_size = _max(0, _min(a_shape[axis1] + _min(offset, 0),
a_shape[axis2] - _max(offset, 0)))
return lax.slice_in_dim(d, 0, diag_size, axis=-1)
@_wraps(np.diag)
def diag(v, k=0):
v_shape = shape(v)
if len(v_shape) == 1:
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
n = v_shape[0] + _abs(k)
v = lax.pad(v, zero(v), ((_max(0, k), _max(0, -k), 0),))
return where(eye(n, k=k, dtype=bool), v, zeros_like(v))
elif len(v_shape) == 2:
return diagonal(v, offset=k)
else:
raise ValueError("diag input must be 1d or 2d")
_SCALAR_VALUE_DOC="""\
This differs from np.diagflat for some scalar values of v:
jax always returns a two-dimensional array, whereas numpy may
return a scalar depending on the type of v.
"""
@_wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC)
def diagflat(v, k=0):
v = ravel(v)
v_length = len(v)
adj_length = v_length + _abs(k)
res = zeros(adj_length*adj_length, dtype=v.dtype)
i = arange(0, adj_length-_abs(k))
if (k >= 0):
fi = i+k+i*adj_length
else:
fi = i+(i-k)*adj_length
res = ops.index_update(res, ops.index[fi], v)
res = res.reshape(adj_length,adj_length)
return res
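# Worked example of the flat-index formula above (values assumed): for
# v = [1, 2, 3] and k = 1, adj_length = 4, i = [0, 1, 2], and
# fi = i + k + i*adj_length = [1, 6, 11], which are exactly the positions of
# the first superdiagonal in the flattened 4x4 result.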
@_wraps(np.polyval)
def polyval(p, x):
if isinstance(p, np.poly1d):
p = np.asarray(p)
if isinstance(x, np.poly1d):
y = 0
else:
y = zeros_like(x)
for i in range(len(p)):
y = y * x + p[i]
return y
@_wraps(np.polyadd)
def polyadd(a1, a2):
a1 = asarray(a1)
a2 = asarray(a2)
if a2.shape[0] <= a1.shape[0]:
return a1.at[-a2.shape[0]:].add(a2)
else:
return a2.at[-a1.shape[0]:].add(a1)
@_wraps(np.polyder)
def polyder(p, m=1):
p = asarray(p)
if m < 0:
raise ValueError("Order of derivative must be positive")
if m == 0:
return p
if m % 1:
raise ValueError("m must be an integer")
coeff = (arange(len(p), m, -1) - 1 - arange(m)[:, newaxis]).prod(0)
return p[:-m] * coeff
@_wraps(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
nz = asarray(filt) == 0
if all(nz):
return empty(0, _dtype(filt))
start = argmin(nz) if 'f' in trim.lower() else 0
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt[start:len(filt) - end]
_LEADING_ZEROS_DOC="""\
Setting trim_leading_zeros=True makes the output match that of numpy,
but prevents the function from being used in compiled code.
"""
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1, a2, *, trim_leading_zeros=False):
if isinstance(a1, np.poly1d):
a1 = asarray(a1)
if isinstance(a2, np.poly1d):
a2 = asarray(a2)
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
if len(a1) == 0:
a1 = asarray([0.])
if len(a2) == 0:
a2 = asarray([0.])
val = convolve(a1, a2, mode='full')
return val
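# Illustrative sketch of the trim_leading_zeros option above (values assumed):
#   polymul(array([0., 1., 2.]), array([3., 4.]))
#     -> [0., 3., 10., 8.]   (leading zero retained; output shape is static)
#   polymul(array([0., 1., 2.]), array([3., 4.]), trim_leading_zeros=True)
#     -> [3., 10., 8.]       (matches np.polymul, but shape depends on values)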
@_wraps(np.polysub)
def polysub(a1, a2):
return polyadd(asarray(a1), -asarray(a2))
@_wraps(np.append)
def append(arr, values, axis=None):
if axis is None:
return concatenate([ravel(arr), ravel(values)], 0)
else:
return concatenate([arr, values], axis=axis)
@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
num_dims = ndim(arr)
axis = _canonicalize_axis(axis, num_dims)
func = lambda arr: func1d(arr, *args, **kwargs)
for i in range(1, num_dims - axis):
func = jax.vmap(func, in_axes=i, out_axes=-1)
for i in range(axis):
func = jax.vmap(func, in_axes=0, out_axes=0)
return func(arr)
@_wraps(np.apply_over_axes)
def apply_over_axes(func, a, axes):
for axis in axes:
b = func(a, axis=axis)
if b.ndim == a.ndim:
a = b
elif b.ndim == a.ndim - 1:
a = expand_dims(b, axis)
else:
raise ValueError("function is not returning an array of the correct shape")
return a
### Tensor contraction operations
@_wraps(np.dot, lax_description=_PRECISION_DOC)
def dot(a, b, *, precision=None): # pylint: disable=missing-docstring
_check_arraylike("dot", a, b)
a, b = _promote_dtypes(a, b)
a_ndim, b_ndim = ndim(a), ndim(b)
if a_ndim == 0 or b_ndim == 0:
return lax.mul(a, b)
if _max(a_ndim, b_ndim) <= 2:
return lax.dot(a, b, precision=precision)
if b_ndim == 1:
contract_dims = ((a_ndim - 1,), (0,))
else:
contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
batch_dims = ((), ())
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
@_wraps(np.matmul, lax_description=_PRECISION_DOC)
def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
_check_arraylike("matmul", a, b)
for i, x in enumerate((a, b)):
if ndim(x) < 1:
msg = (f"matmul input operand {i} must have ndim at least 1, "
f"but it has ndim {ndim(x)}")
raise ValueError(msg)
a, b = _promote_dtypes(a, b)
a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
a_batch_dims = shape(a)[:-2] if a_is_mat else ()
b_batch_dims = shape(b)[:-2] if b_is_mat else ()
num_batch_dims = _max(len(a_batch_dims), len(b_batch_dims))
a_batch_dims = (None,) * (num_batch_dims - len(a_batch_dims)) + a_batch_dims
b_batch_dims = (None,) * (num_batch_dims - len(b_batch_dims)) + b_batch_dims
# Dimensions to squeeze from the inputs.
a_squeeze = []
b_squeeze = []
# Positions of batch dimensions in squeezed inputs.
a_batch = []
b_batch = []
# Desired index in final output of each kind of dimension, in the order that
# lax.dot_general will emit them.
idx_batch = []
idx_a_other = [] # other = non-batch, non-contracting.
idx_b_other = []
for i, (ba, bb) in enumerate(zip(a_batch_dims, b_batch_dims)):
if ba is None:
idx_b_other.append(i)
elif bb is None:
idx_a_other.append(i)
elif ba == 1:
idx_b_other.append(i)
a_squeeze.append(len(idx_batch) + len(idx_a_other) + len(a_squeeze))
elif bb == 1:
idx_a_other.append(i)
b_squeeze.append(len(idx_batch) + len(idx_b_other) + len(b_squeeze))
elif ba == bb:
a_batch.append(len(idx_batch) + len(idx_a_other))
b_batch.append(len(idx_batch) + len(idx_b_other))
idx_batch.append(i)
else:
raise ValueError("Incompatible shapes for matmul arguments: {} and {}"
.format(shape(a), shape(b)))
if a_is_mat: idx_a_other.append(num_batch_dims)
if b_is_mat: idx_b_other.append(num_batch_dims + a_is_mat)
perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other]))
a = lax.squeeze(a, tuple(a_squeeze))
b = lax.squeeze(b, tuple(b_squeeze))
out = lax.dot_general(
a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
precision=precision)
return lax.transpose(out, perm)
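# Illustrative sketch (values assumed) of the batch-dimension handling above:
# for a of shape (3, 1, 4, 5) and b of shape (2, 5, 6), the leading dimensions
# broadcast to (3, 2), the size-1 batch dimension of `a` is squeezed rather
# than materialized, and the result has shape (3, 2, 4, 6).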
@_wraps(np.vdot, lax_description=_PRECISION_DOC)
def vdot(a, b, *, precision=None):
if issubdtype(_dtype(a), complexfloating):
a = conj(a)
return dot(a.ravel(), b.ravel(), precision=precision)
@_wraps(np.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, *, precision=None):
_check_arraylike("tensordot", a, b)
a_ndim = ndim(a)
b_ndim = ndim(b)
a, b = _promote_dtypes(a, b)
if type(axes) is int:
if axes > _min(a_ndim, b_ndim):
msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
raise TypeError(msg.format(axes, a.shape, b.shape))
contracting_dims = tuple(range(a_ndim - axes, a_ndim)), tuple(range(axes))
elif type(axes) in (list, tuple) and len(axes) == 2:
ax1, ax2 = axes
if type(ax1) == type(ax2) == int:
contracting_dims = ((_canonicalize_axis(ax1, a_ndim),),
(_canonicalize_axis(ax2, b_ndim),))
elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
if len(ax1) != len(ax2):
msg = "tensordot requires axes lists to have equal length, got {} and {}."
raise TypeError(msg.format(ax1, ax2))
contracting_dims = (tuple(_canonicalize_axis(i, a_ndim) for i in ax1),
tuple(_canonicalize_axis(i, b_ndim) for i in ax2))
else:
msg = "tensordot requires both axes lists to be either ints, tuples or lists, got {} and {}"
raise TypeError(msg.format(ax1, ax2))
else:
msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
"of lists/tuples of ints.")
raise TypeError(msg)
return lax.dot_general(a, b, (contracting_dims, ((), ())),
precision=precision)
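# Illustrative sketch of the accepted `axes` forms above (values assumed):
#   tensordot(a, b, 2)                # contract the last 2 axes of a with the
#                                     # first 2 axes of b
#   tensordot(a, b, (1, 0))           # contract axis 1 of a with axis 0 of b
#   tensordot(a, b, ((0, 2), (1, 0))) # contract axes 0,2 of a with axes 1,0 of b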
@_wraps(np.einsum, lax_description=_PRECISION_DOC)
def einsum(*operands, optimize='greedy', precision=None):
optimize = 'greedy' if optimize is True else optimize
# using einsum_call=True here is an internal api for opt_einsum
operands, contractions = opt_einsum.contract_path(
*operands, einsum_call=True, use_blas=True, optimize=optimize)
contractions = tuple(data[:3] for data in contractions)
return _einsum(operands, contractions, precision)
@_wraps(np.einsum_path)
def einsum_path(subscripts, *operands, optimize='greedy'):
# using einsum_call=True here is an internal api for opt_einsum
return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)
def _removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))
@partial(jit, static_argnums=(1, 2))
def _einsum(operands: Sequence,
contractions: Sequence[Tuple[Tuple[int, ...], Set[str], str]],
precision):
operands = list(_promote_dtypes(*operands))
def sum(x, axes):
return lax.reduce(x, np.array(0, x.dtype),
lax.add if x.dtype != bool_ else lax.bitwise_or, axes)
def sum_uniques(operand, names, uniques):
if uniques:
axes = [names.index(name) for name in uniques]
operand = sum(operand, axes)
names = _removechars(names, uniques)
return operand, names
def sum_repeats(operand, names, counts, keep_names):
for name, count in counts.items():
if count > 1:
axes = [i for i, n in enumerate(names) if n == name]
eye = lax._delta(operand.dtype, operand.shape, axes)
if name not in keep_names:
operand = sum(operand * eye, axes)
names = names.replace(name, '')
else:
operand = sum(operand * eye, axes[:-1])
names = names.replace(name, '', count - 1)
return operand, names
def filter_singleton_dims(operand, names, other_shape, other_names):
s = shape(operand)
new_shape = []
new_names = []
for i, d in enumerate(names):
other_i = other_names.find(d)
if s[i] != 1 or other_i == -1 or other_shape[other_i] == 1:
new_shape.append(s[i])
new_names.append(d)
return reshape(operand, tuple(new_shape)), "".join(new_names)
for operand_indices, contracted_names_set, einstr in contractions:
contracted_names = sorted(contracted_names_set)
input_str, result_names = einstr.split('->')
input_names = input_str.split(',')
# switch on the number of operands to be processed in this loop iteration.
# every case here sets 'operand' and 'names'.
if len(operand_indices) == 1:
operand = operands.pop(operand_indices[0])
names, = input_names
counts = collections.Counter(names)
# sum out unique contracted indices with a single reduce-sum
uniques = [name for name in contracted_names if counts[name] == 1]
operand, names = sum_uniques(operand, names, uniques)
# for every repeated index, do a contraction against an identity matrix
operand, names = sum_repeats(operand, names, counts, result_names)
elif len(operand_indices) == 2:
lhs, rhs = map(operands.pop, operand_indices)
lhs_names, rhs_names = input_names
# handle cases where one side of a contracting or batch dimension is 1
# but its counterpart is not.
lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
rhs_names)
rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
lhs_names)
lhs_counts = collections.Counter(lhs_names)
rhs_counts = collections.Counter(rhs_names)
# sum out unique contracted indices in lhs and rhs
lhs_uniques = [name for name in contracted_names
if lhs_counts[name] == 1 and rhs_counts[name] == 0]
lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)
rhs_uniques = [name for name in contracted_names
if rhs_counts[name] == 1 and lhs_counts[name] == 0]
rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)
# for every repeated index, contract against an identity matrix
lhs, lhs_names = sum_repeats(lhs, lhs_names, lhs_counts,
result_names + rhs_names)
rhs, rhs_names = sum_repeats(rhs, rhs_names, rhs_counts,
result_names + lhs_names)
lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
batch_names = [x for x in result_names if x in lhs_and_rhs_names]
lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
for n in batch_names)
# NOTE(mattjj): this can fail non-deterministically in python3, maybe
# due to opt_einsum
assert _all(
name in lhs_names and name in rhs_names and
lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
for name in contracted_names)
# contract using lax.dot_general
batch_names_str = ''.join(batch_names)
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
deleted_names = batch_names_str + ''.join(contracted_names)
names = (batch_names_str + _removechars(lhs_names, deleted_names)
+ _removechars(rhs_names, deleted_names))
else:
raise NotImplementedError # if this is actually reachable, open an issue!
# the resulting 'operand' with axis labels 'names' should be a permutation
# of the desired result
assert len(names) == len(result_names) == len(set(names))
assert set(names) == set(result_names)
if names != result_names:
perm = tuple([names.index(name) for name in result_names])
operand = lax.transpose(operand, perm)
operands.append(operand) # used in next iteration
return operands[0]
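# Illustrative sketch (values assumed) of how a contraction step maps onto
# lax.dot_general: for einsum('ij,jk->ik', x, y) the contracted name 'j'
# becomes contracting dimensions ((1,), (0,)) with no batch dimensions, while
# einsum('bij,bjk->bik', x, y) additionally maps the shared name 'b' to batch
# dimensions ((0,), (0,)).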
def _movechars(s, src, dst):
"""Helper for einsum string munging, like moveaxis on identifier strings."""
chars = [c for i, c in enumerate(s) if i not in src]
for i, j in sorted(zip(dst, src)):
chars.insert(i, s[j])
return ''.join(chars)
@_wraps(np.inner, lax_description=_PRECISION_DOC)
def inner(a, b, *, precision=None):
if ndim(a) == 0 or ndim(b) == 0:
return a * b
return tensordot(a, b, (-1, -1), precision=precision)
@_wraps(np.outer)
def outer(a, b, out=None):
if out:
raise NotImplementedError("The 'out' argument to outer is not supported.")
a, b = _promote_dtypes(a, b)
return ravel(a)[:, None] * ravel(b)[None, :]
@partial(jit, static_argnums=(2, 3, 4))
def _cross(a, b, axisa, axisb, axisc):
a = moveaxis(a, axisa, -1)
b = moveaxis(b, axisb, -1)
if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
raise ValueError("Dimension must be either 2 or 3 for cross product")
if a.shape[-1] == 2 and b.shape[-1] == 2:
return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
a0 = a[..., 0]
a1 = a[..., 1]
a2 = a[..., 2] if a.shape[-1] == 3 else zeros_like(a0)
b0 = b[..., 0]
b1 = b[..., 1]
b2 = b[..., 2] if b.shape[-1] == 3 else zeros_like(b0)
c = array([a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0])
return moveaxis(c, 0, axisc)
@_wraps(np.cross)
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if axis is not None:
axisa = axis
axisb = axis
axisc = axis
return _cross(a, b, axisa, axisb, axisc)
@_wraps(np.kron)
def kron(a, b):
a, b = _promote_dtypes(a, b)
if ndim(a) < ndim(b):
a = reshape(a, (1,) * (ndim(b) - ndim(a)) + shape(a))
elif ndim(b) < ndim(a):
b = reshape(b, (1,) * (ndim(a) - ndim(b)) + shape(b))
a_reshaped = reshape(a, [i for d in shape(a) for i in (d, 1)])
b_reshaped = reshape(b, [i for d in shape(b) for i in (1, d)])
out_shape = tuple(np.multiply(shape(a), shape(b)))
return reshape(lax.mul(a_reshaped, b_reshaped), out_shape)
@_wraps(np.vander)
def vander(x, N=None, increasing=False):
x = asarray(x)
dtype = _dtype(x)
if ndim(x) != 1:
raise ValueError("x must be a one-dimensional array")
x_shape = shape(x)
N = N or x_shape[0]
if N < 0:
raise ValueError("N must be nonnegative")
iota = lax.iota(dtype, N)
if not increasing:
iota = lax.sub(lax._const(iota, N - 1), iota)
return power(x[..., None], iota)
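# Illustrative example (values assumed): vander(array([1., 2., 3.]), 3) gives
#   [[1., 1., 1.],
#    [4., 2., 1.],
#    [9., 3., 1.]]
# i.e. decreasing powers by default; increasing=True reverses the column order.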
### Misc
@_wraps(np.argwhere)
def argwhere(a):
result = transpose(vstack(nonzero(a)))
if ndim(a) == 0:
return result[:0].reshape(result.shape[0], 0)
return result.reshape(result.shape[0], ndim(a))
@_wraps(np.argmax)
def argmax(a, axis=None):
if axis is None:
a = ravel(a)
axis = 0
if a.shape[axis] == 0:
raise ValueError("attempt to get argmax of an empty sequence")
return lax.argmax(a, _canonicalize_axis(axis, a.ndim), int64)
@_wraps(np.argmin)
def argmin(a, axis=None):
if axis is None:
a = ravel(a)
axis = 0
if a.shape[axis] == 0:
raise ValueError("attempt to get argmin of an empty sequence")
return lax.argmin(a, _canonicalize_axis(axis, a.ndim), int64)
_NANARG_DOC = """\
Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise
an error.
"""
@_wraps(np.nanargmax, lax_description=_NANARG_DOC.format("max"))
def nanargmax(a, axis=None):
if not issubdtype(_dtype(a), inexact):
return argmax(a, axis=axis)
nan_mask = isnan(a)
a = where(nan_mask, -inf, a)
res = argmax(a, axis=axis)
return where(all(nan_mask, axis=axis), -1, res)
@_wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min"))
def nanargmin(a, axis=None):
if not issubdtype(_dtype(a), inexact):
return argmin(a, axis=axis)
nan_mask = isnan(a)
a = where(nan_mask, inf, a)
res = argmin(a, axis=axis)
return where(all(nan_mask, axis=axis), -1, res)
@_wraps(np.sort)
def sort(a, axis=-1, kind='quicksort', order=None):
if kind != 'quicksort':
warnings.warn("'kind' argument to sort is ignored.")
if order is not None:
raise ValueError("'order' argument to sort is not supported.")
if axis is None:
return lax.sort(a.ravel(), dimension=0)
else:
return lax.sort(a, dimension=_canonicalize_axis(axis, ndim(a)))
@_wraps(np.sort_complex)
def sort_complex(a):
a = lax.sort(a, dimension=0)
return lax.convert_element_type(a, result_type(a, dtypes.canonicalize_dtype(complex_)))
@_wraps(np.lexsort)
def lexsort(keys, axis=-1):
keys = tuple(keys)
if len(keys) == 0:
raise TypeError("need sequence of keys with len > 0 in lexsort")
if len({shape(key) for key in keys}) > 1:
raise ValueError("all keys need to be the same shape")
if ndim(keys[0]) == 0:
return np.int64(0)
axis = _canonicalize_axis(axis, ndim(keys[0]))
iota = lax.broadcasted_iota(np.int64, shape(keys[0]), axis)
return lax.sort((*keys[::-1], iota), dimension=axis, num_keys=len(keys))[-1]
@_wraps(np.argsort)
def argsort(a, axis=-1, kind='quicksort', order=None):
if kind != 'quicksort':
warnings.warn("'kind' argument to argsort is ignored.")
if order is not None:
raise ValueError("'order' argument to argsort is not supported.")
if axis is None:
return argsort(a.ravel(), 0)
else:
axis = _canonicalize_axis(axis, ndim(a))
iota = lax.broadcasted_iota(np.int64, shape(a), axis)
_, perm = lax.sort_key_val(a, iota, dimension=axis)
return perm
@_wraps(np.msort)
def msort(a):
return sort(a, axis=0)
@partial(jit, static_argnums=(2,))
def _roll(a, shift, axis):
a = asarray(a)
a_shape = shape(a)
if axis is None:
return lax.reshape(roll(ravel(a), shift, axis=0), a_shape)
a_ndim = len(a_shape)
shift = asarray(shift)
axis = np.asarray(axis)
b_shape = lax.broadcast_shapes(shift.shape, axis.shape, (1,))
if len(b_shape) != 1:
msg = "'shift' and 'axis' arguments to roll must be scalars or 1D arrays"
raise ValueError(msg)
for x, i in zip(broadcast_to(shift, b_shape),
np.broadcast_to(axis, b_shape)):
i = _canonicalize_axis(i, a_ndim)
x = remainder(x, (a_shape[i] or 1))
a = lax.concatenate((a, a), i)
a = lax.dynamic_slice_in_dim(a, a_shape[i] - x, a_shape[i], axis=i)
return a
@_wraps(np.roll)
def roll(a, shift, axis=None):
return _roll(a, shift, axis)
@_wraps(np.rollaxis)
def rollaxis(a, axis, start=0):
a_ndim = ndim(a)
axis = _canonicalize_axis(axis, a_ndim)
if not (-a_ndim <= start <= a_ndim):
raise ValueError(f"start={start} must satisfy {-a_ndim}<=start<={a_ndim}")
if start < 0:
start += a_ndim
if start > axis:
start -= 1
return moveaxis(a, axis, start)
@_wraps(np.packbits)
def packbits(a, axis=None, bitorder='big'):
a = asarray(a)
if not (issubdtype(dtype(a), integer) or issubdtype(dtype(a), bool_)):
raise TypeError('Expected an input array of integer or boolean data type')
if bitorder not in ['little', 'big']:
raise ValueError("'order' must be either 'little' or 'big'")
a = (a > 0).astype('uint8')
bits = arange(8, dtype='uint8')
if bitorder == 'big':
bits = bits[::-1]
if axis is None:
a = ravel(a)
axis = 0
a = swapaxes(a, axis, -1)
remainder = a.shape[-1] % 8
if remainder:
a = pad(a, (a.ndim - 1) * [(0, 0)] + [(0, 8 - remainder)])
a = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8))
packed = (a << bits).sum(-1).astype('uint8')
return swapaxes(packed, axis, -1)
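# Illustrative example (values assumed): with the default bitorder='big',
#   packbits(array([1, 0, 1, 1, 0, 0, 1, 0]))   # -> [178]  (0b10110010)
# and unpackbits(array([178], dtype=uint8)) recovers the original bits.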
@_wraps(np.unpackbits)
def unpackbits(a, axis=None, count=None, bitorder='big'):
a = asarray(a)
if dtype(a) != uint8:
raise TypeError("Expected an input array of unsigned byte data type")
if bitorder not in ['little', 'big']:
raise ValueError("'order' must be either 'little' or 'big'")
bits = asarray(1) << arange(8, dtype='uint8')
if bitorder == 'big':
bits = bits[::-1]
if axis is None:
a = a.ravel()
axis = 0
a = swapaxes(a, axis, -1)
unpacked = ((a[..., None] & bits) > 0).astype('uint8')
unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))[..., :count]
return swapaxes(unpacked, axis, -1)
@_wraps(np.take)
def take(a, indices, axis=None, out=None, mode=None):
if out:
raise NotImplementedError("The 'out' argument to np.take is not supported.")
a = asarray(a)
indices = asarray(indices)
if axis is None:
a = ravel(a)
axis = 0
axis = _canonicalize_axis(axis, ndim(a))
if mode == "raise":
# TODO(phawkins): we have no way to report out of bounds errors yet.
raise NotImplementedError("The 'raise' mode to np.take is not supported.")
elif mode == "wrap":
indices = mod(indices, _constant_like(indices, a.shape[axis]))
elif mode != "clip" and mode is not None:
raise ValueError("Invalid mode '{}' for np.take".format(mode))
index_dims = len(shape(indices))
slice_sizes = list(shape(a))
slice_sizes[axis] = _min(indices.size, 1)
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(
list(range(axis)) +
list(range(axis + index_dims, len(a.shape) + index_dims - 1))),
collapsed_slice_dims=(axis,),
start_index_map=(axis,))
return lax.gather(a, indices[..., None], dimension_numbers=dnums,
slice_sizes=tuple(slice_sizes))
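# Illustrative sketch of the index modes above (values assumed): for
# a = [10, 20, 30] and indices = [0, 4],
#   take(a, indices, mode='wrap')  # -> [10, 20]  (4 wraps to 4 % 3 == 1)
#   take(a, indices, mode='clip')  # -> [10, 30]  (4 clips to the last index)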
def _normalize_index(index, axis_size):
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
if type(axis_size) is Poly:
return index + axis_size if index < 0 else index
return lax.select(
lax.lt(index, _constant_like(index, 0)),
lax.add(index, _constant_like(index, axis_size)),
index)
@partial(jit, static_argnums=(2,))
def _take_along_axis(arr, indices, axis):
if axis is None:
if ndim(indices) != 1:
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
raise ValueError(msg.format(indices.shape))
return take_along_axis(arr.ravel(), indices, 0)
rank = ndim(arr)
if rank != ndim(indices):
msg = "indices and arr must have the same number of dimensions; {} vs. {}"
raise ValueError(msg.format(ndim(indices), ndim(arr)))
axis = _canonicalize_axis(axis, rank)
def replace(tup, val):
lst = list(tup)
lst[axis] = val
return tuple(lst)
bcast_shape = lax.broadcast_shapes(replace(arr.shape, 1), replace(indices.shape, 1))
indices = broadcast_to(indices, replace(bcast_shape, indices.shape[axis]))
arr = broadcast_to(arr, replace(bcast_shape, arr.shape[axis]))
axis_size = arr.shape[axis]
arr_shape = replace(arr.shape, 1)
idx_shape = indices.shape
out_shape = lax.broadcast_shapes(idx_shape, arr_shape)
index_dims = [i for i, idx in enumerate(idx_shape) if i == axis or idx != 1]
gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,)
gather_indices = []
slice_sizes = []
offset_dims = []
start_index_map = []
collapsed_slice_dims = []
j = 0
for i in range(rank):
if i == axis:
indices = _normalize_index(indices, axis_size)
gather_indices.append(lax.reshape(indices, gather_index_shape))
slice_sizes.append(1)
start_index_map.append(i)
collapsed_slice_dims.append(i)
j += 1
elif idx_shape[i] != 1:
iota = lax.iota(_dtype(indices), out_shape[i])
if not config.omnistaging_enabled:
iota = lax.tie_in(arr, iota)
iota = lax.broadcast_in_dim(iota, gather_index_shape, (j,))
gather_indices.append(iota)
slice_sizes.append(1)
start_index_map.append(i)
collapsed_slice_dims.append(i)
j += 1
else:
# If idx_shape[i] == 1, we can just take the entirety of the arr's axis
# and avoid forming an iota index.
offset_dims.append(i)
slice_sizes.append(arr_shape[i])
gather_indices = lax.concatenate(gather_indices, dimension=j)
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(offset_dims),
collapsed_slice_dims=tuple(collapsed_slice_dims),
start_index_map=tuple(start_index_map))
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes))
@_wraps(getattr(np, "take_along_axis", None), update_doc=False)
def take_along_axis(arr, indices, axis):
return _take_along_axis(arr, indices, axis)
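# Illustrative example (values assumed): pairing argsort with take_along_axis
# sorts each row of a 2D array:
#   a = array([[3, 1], [2, 4]])
#   take_along_axis(a, argsort(a, axis=1), axis=1)   # -> [[1, 3], [2, 4]]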
### SetOps
@partial(jit, static_argnums=1)
def _unique1d_sorted_mask(ar, optional_indices=False):
"""
Helper function for unique which is jit-able
"""
ar = asarray(ar).flatten()
if optional_indices:
perm = ar.argsort()
aux = ar[perm]
else:
aux = ar.sort()
mask = empty(aux.shape, dtype=bool_)
mask = ops.index_update(mask, ops.index[:1], True)
mask = ops.index_update(mask, ops.index[1:], aux[1:] != aux[:-1])
if optional_indices:
return aux, mask, perm
else:
return aux, mask
def _unique1d(ar, return_index=False, return_inverse=False,
return_counts=False):
"""
Find the unique elements of an array, ignoring shape.
"""
optional_indices = return_index or return_inverse
if optional_indices:
aux, mask, perm = _unique1d_sorted_mask(ar, optional_indices)
else:
aux, mask = _unique1d_sorted_mask(ar, optional_indices)
ret = (aux[mask],)
if return_index:
ret += (perm[mask],)
if return_inverse:
imask = cumsum(mask) - 1
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
inv_idx = ops.index_update(inv_idx, perm, imask)
ret += (inv_idx,)
if return_counts:
idx = concatenate(nonzero(mask) + (array([mask.size]),))
ret += (diff(idx),)
return ret
@_wraps(np.unique)
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis=None):
if iscomplexobj(ar):
raise NotImplementedError(
"np.unique is not implemented for complex valued arrays")
if axis is None:
ret = _unique1d(ar, return_index, return_inverse, return_counts)
if len(ret) == 1:
return ret[0]
else:
return ret
raise NotImplementedError(
"np.unique is not implemented for the axis argument")
### Indexing
def _rewriting_take(arr, idx):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
arr = asarray(arr)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
return _gather(arr, treedef, static_idx, dynamic_idx)
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(1, 2))
def _gather(arr, treedef, static_idx, dynamic_idx):
idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
y = arr
# Avoid calling gather if the slice shape is empty, both as a fast path and to
# handle cases like zeros(0)[array([], int32)].
if _prod(indexer.slice_shape) == 0:
return zeros(indexer.slice_shape, dtype=y.dtype)
# We avoid generating a gather when indexer.gather_indices is empty.
if indexer.gather_indices.size:
y = lax.gather(y, indexer.gather_indices, indexer.dnums,
indexer.gather_slice_shape)
# Reverses axes with negative strides.
if indexer.reversed_y_dims:
y = lax.rev(y, indexer.reversed_y_dims)
# This adds np.newaxis/None dimensions.
return expand_dims(y, indexer.newaxis_dims)
_Indexer = collections.namedtuple("_Indexer", [
# The expected shape of the slice output.
"slice_shape",
# The slice shape to pass to lax.gather().
"gather_slice_shape",
# The gather indices to use.
"gather_indices",
# A GatherDimensionNumbers object describing the gather to perform.
"dnums",
# Slice dimensions that have negative strides, and so must be reversed after
# the gather.
"reversed_y_dims",
# Keep track of any axes created by `newaxis`. These must be inserted for
# gathers and eliminated for scatters.
"newaxis_dims",
])
def _split_index_for_jit(idx):
"""Splits indices into necessarily-static and dynamic parts.
  Used to pass indices into a `jit`-compiled function.
"""
  # Convert list indices to tuples in certain cases (this form of indexing is
  # deprecated by NumPy).
idx = _eliminate_deprecated_list_indexing(idx)
# Expand any (concrete) boolean indices. We can then use advanced integer
# indexing logic to handle them.
idx = _expand_bool_indices(idx)
leaves, treedef = tree_flatten(idx)
dynamic = [None] * len(leaves)
static = [None] * len(leaves)
for i, x in enumerate(leaves):
if x is Ellipsis:
static[i] = x
elif isinstance(x, slice):
# slice objects aren't hashable.
static[i] = (x.start, x.stop, x.step)
else:
dynamic[i] = x
return treedef, tuple(static), dynamic
def _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
"""Recombines indices that were split by _split_index_for_jit."""
idx = []
for s, d in zip(static_idx, dynamic_idx):
if d is not None:
idx.append(d)
elif isinstance(s, tuple):
idx.append(slice(s[0], s[1], s[2]))
else:
idx.append(s)
return treedef.unflatten(idx)
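# A minimal round-trip sketch: `_example_split_index_roundtrip` is a hypothetical
# helper illustrating the split above. Slices become hashable (start, stop, step)
# tuples on the static side, scalar/array indices stay on the dynamic side, and
# merging reconstructs the original index tree.
def _example_split_index_roundtrip():
  idx = (2, slice(1, 5, None), Ellipsis)
  treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
  # static_idx is roughly (None, (1, 5, None), Ellipsis); dynamic_idx is [2, None, None].
  return _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)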
def _int(aval):
return not aval.shape and issubdtype(aval.dtype, integer)
def _index_to_gather(x_shape, idx):
# Remove ellipses and add trailing slice(None)s.
idx = _canonicalize_tuple_index(len(x_shape), idx)
# Check for advanced indexing:
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
# Do the advanced indexing axes appear contiguously? If not, NumPy semantics
# move the advanced axes to the front.
advanced_axes_are_contiguous = False
advanced_indexes = None
# The positions of the advanced indexing axes in `idx`.
idx_advanced_axes = []
  # The positions of the advanced indices in x's shape, after None axes have
  # been removed; these axes are collapsed by the gather. See below.
x_advanced_axes = None
if _is_advanced_int_indexer(idx):
idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
advanced_pairs = (
(asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
if isinstance(e, (Sequence, ndarray)))
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
for e, i, j in advanced_pairs)
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
advanced_axes_are_contiguous = np.all(np.diff(idx_advanced_axes) == 1)
x_axis = 0 # Current axis in x.
y_axis = 0 # Current axis in y, before collapsing. See below.
collapsed_y_axis = 0 # Current axis in y, after collapsing.
# Scatter dimension numbers.
offset_dims = []
collapsed_slice_dims = []
start_index_map = []
use_64bit_index = _any([type(d) is Poly or d >= (1 << 31) for d in x_shape])
index_dtype = int64 if use_64bit_index else int32
gather_indices = np.zeros((0,), dtype=index_dtype) # use np to save a compilation
# We perform three transformations to y before the scatter op, in order:
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
# the right shape.
slice_shape = []
# Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
# indices, which the scatter cannot remove itself.
newaxis_dims = []
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
reversed_y_dims = []
gather_slice_shape = []
for idx_pos, i in enumerate(idx):
    # Handle the advanced indices here if:
    # * the advanced indices were not contiguous and we are at the front, or
    # * the advanced indices are contiguous and we are at the position of the
    #   first advanced index.
if (advanced_indexes is not None and
(advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
not advanced_axes_are_contiguous and idx_pos == 0)):
advanced_indexes = broadcast_arrays(*advanced_indexes)
shape = advanced_indexes[0].shape
ndim = len(shape)
advanced_indexes = [
lax.convert_element_type(lax.reshape(a, shape + (1,)), index_dtype)
for a in advanced_indexes]
# Broadcast gather_indices from [..., k] to [..., 1, 1, ..., 1, k].
gather_indices = lax.broadcast_in_dim(
gather_indices, np.insert(gather_indices.shape, -1, shape),
tuple(range(gather_indices.ndim - 1)) + (gather_indices.ndim + ndim - 1,))
gather_indices = concatenate([gather_indices] + advanced_indexes, -1)
start_index_map.extend(x_advanced_axes)
collapsed_slice_dims.extend(x_advanced_axes)
slice_shape.extend(shape)
y_axis += ndim
collapsed_y_axis += ndim
# Per-index bookkeeping for advanced indexes.
if idx_pos in idx_advanced_axes:
x_axis += 1
gather_slice_shape.append(1)
continue
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
# Handle basic int indexes.
if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i):
if x_shape[x_axis] == 0:
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
i = _normalize_index(i, x_shape[x_axis])
if type(i) is Poly:
# dummy index if i is polynomial, doesn't matter for shape inference
# TODO(mattjj,j-towns,juliuskunze): revise this logic
i = 0
i = lax.convert_element_type(i, index_dtype)
i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,))
gather_indices = concatenate((gather_indices, i), -1)
collapsed_slice_dims.append(x_axis)
gather_slice_shape.append(1)
start_index_map.append(x_axis)
x_axis += 1
# Handle np.newaxis (None)
elif i is None:
slice_shape.append(1)
newaxis_dims.append(y_axis)
y_axis += 1
# Handle slice(None)
elif _is_slice_none(i):
slice_shape.append(x_shape[x_axis])
gather_slice_shape.append(x_shape[x_axis])
offset_dims.append(collapsed_y_axis)
collapsed_y_axis += 1
y_axis += 1
x_axis += 1
# Handle slice index (only static, otherwise an error is raised)
elif isinstance(i, slice):
if not _all(elt is None or type(elt) is Poly
or type(core.get_aval(elt)) is ConcreteArray
for elt in (i.start, i.stop, i.step)):
msg = ("Array slice indices must have static start/stop/step to be used "
"with NumPy indexing syntax. To index a statically sized "
"array at a dynamic position, try lax.dynamic_slice/"
"dynamic_update_slice (JAX does not support dynamically sized "
"arrays within JIT compiled functions).")
raise IndexError(msg)
start, limit, stride, needs_rev = _static_idx(i, x_shape[x_axis])
if needs_rev:
reversed_y_dims.append(collapsed_y_axis)
if stride == 1:
i = lax.convert_element_type(start, index_dtype)
i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,))
gather_indices = concatenate((gather_indices, i), -1)
slice_shape.append(limit - start)
gather_slice_shape.append(limit - start)
offset_dims.append(collapsed_y_axis)
start_index_map.append(x_axis)
else:
i = arange(start, limit, stride, dtype=index_dtype)
size = i.shape[0]
slice_shape.append(size)
gather_slice_shape.append(1)
gather_indices_shape = tuple(gather_indices.shape[:-1]) + (size,)
i = lax.broadcast_in_dim(
i, shape=gather_indices_shape + (1,),
broadcast_dimensions=(len(gather_indices_shape) - 1,))
gather_indices = lax.broadcast_in_dim(
gather_indices,
shape=gather_indices_shape + (len(start_index_map),),
broadcast_dimensions=(
tuple(range(len(gather_indices_shape) - 1)) +
(len(gather_indices_shape),)))
gather_indices = concatenate(
(gather_indices, i), len(gather_indices_shape))
start_index_map.append(x_axis)
collapsed_slice_dims.append(x_axis)
collapsed_y_axis += 1
y_axis += 1
x_axis += 1
else:
if (abstract_i is not None and
not (issubdtype(abstract_i.dtype, integer) or issubdtype(abstract_i.dtype, bool_))):
msg = ("Indexer must have integer or boolean type, got indexer "
"with type {} at position {}, indexer value {}")
raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
msg = "Indexing mode not yet supported. Open a feature request!\n{}"
raise IndexError(msg.format(idx))
dnums = lax.GatherDimensionNumbers(
offset_dims = tuple(offset_dims),
collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)),
start_index_map = tuple(start_index_map)
)
return _Indexer(
slice_shape=slice_shape,
newaxis_dims=tuple(newaxis_dims),
gather_slice_shape=gather_slice_shape,
reversed_y_dims=reversed_y_dims,
dnums=dnums,
gather_indices=gather_indices)
def _should_unpack_list_index(x):
"""Helper for _eliminate_deprecated_list_indexing."""
return (isinstance(x, ndarray) and np.ndim(x) != 0
or isinstance(x, (Sequence, slice))
or x is Ellipsis or x is None)
def _eliminate_deprecated_list_indexing(idx):
# "Basic slicing is initiated if the selection object is a non-array,
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
# objects]". Detects this case and canonicalizes to a tuple. This case is
# deprecated by NumPy and exists for backward compatibility.
if not isinstance(idx, tuple):
if isinstance(idx, Sequence) and not isinstance(idx, ndarray):
if _any(_should_unpack_list_index(i) for i in idx):
idx = tuple(idx)
else:
idx = (idx,)
else:
idx = (idx,)
return idx
def _expand_bool_indices(idx):
"""Converts concrete bool indexes into advanced integer indexes."""
out = []
for i in idx:
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
if (isinstance(abstract_i, ShapedArray) and issubdtype(abstract_i.dtype, bool_)
or isinstance(i, list) and _all(not _shape(e) and issubdtype(_dtype(e), bool_)
for e in i)):
if isinstance(i, list):
i = array(i)
abstract_i = core.get_aval(i)
if not type(abstract_i) is ConcreteArray:
# TODO(mattjj): improve this error by tracking _why_ the indices are not
# concrete
raise IndexError("Array boolean indices must be concrete.")
else:
out.extend(np.where(i))
else:
out.append(i)
return tuple(out)
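# A minimal sketch: `_example_expand_bool_index` is a hypothetical helper showing
# that a concrete boolean mask is rewritten into the equivalent integer indices
# via np.where; the mask below should become (array([0, 2]),).
def _example_expand_bool_index():
  return _expand_bool_indices((array([True, False, True]),))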
def _is_slice_none(idx):
"""Return True if idx is equal to slice(None), False otherwise."""
if isinstance(idx, slice):
return idx.start is None and idx.stop is None and idx.step is None
# TODO(mattjj): clean up this logic
def _is_advanced_int_indexer(idx):
"""Returns True if idx should trigger int array indexing, False otherwise."""
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
assert isinstance(idx, tuple)
if _all(np.ndim(elt) == 0 for elt in idx):
return False
return _all(e is None or e is Ellipsis or isinstance(e, slice)
or _is_int_arraylike(e) for e in idx)
def _is_int_arraylike(x):
"""Returns True if x is array-like with integer dtype, False otherwise."""
return (isinstance(x, int) and not isinstance(x, bool)
or issubdtype(getattr(x, "dtype", None), np.integer)
or isinstance(x, (list, tuple)) and _all(_is_int_arraylike(e) for e in x))
def _canonicalize_tuple_index(arr_ndim, idx):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
len_without_none = _sum(1 for e in idx if e is not None and e is not Ellipsis)
if len_without_none > arr_ndim:
msg = "Too many indices for array: {} non-None/Ellipsis indices for dim {}."
raise IndexError(msg.format(len_without_none, arr_ndim))
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
if ellipsis_index is not None:
if next(ellipses, None) is not None:
msg = "Multiple ellipses (...) not supported: {}."
raise IndexError(msg.format(list(map(type, idx))))
colons = (slice(None),) * (arr_ndim - len_without_none)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif len_without_none < arr_ndim:
colons = (slice(None),) * (arr_ndim - len_without_none)
idx = tuple(idx) + colons
return idx
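# A minimal sketch: `_example_expand_ellipsis` is a hypothetical helper; for a
# rank-3 array, the Ellipsis is replaced by the implied trailing slice(None)
# entries, so the expected result is (0, slice(None), slice(None)).
def _example_expand_ellipsis():
  return _canonicalize_tuple_index(3, (0, Ellipsis))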
def _polymorphic_slice_indices(idx: slice, size: Union[int, Poly]):
# like idx.indices(size), but allows for polymorphic indices and size
# see https://github.com/python/cpython/blob/6d6508765514c7c10719478a0430f5e47c9a96ac/Objects/sliceobject.c#L372
assert isinstance(idx, slice)
step = 1 if idx.step is None else idx.step
step_is_negative = step < 0
lower = -1 if step_is_negative else 0
upper = size + lower
def sanitize(index, default):
if index is None:
return default
elif type(index) is Poly:
return index
elif index < 0:
return _max(index + size, lower)
else:
return _min(index, upper)
start = sanitize(idx.start, default=upper if step_is_negative else lower)
stop = sanitize(idx.stop, default=lower if step_is_negative else upper)
return start, stop, step
def _static_idx(idx: slice, size: Union[int, Poly]):
"""Helper function to compute the static slice start/limit/stride values."""
if _any(type(s) is Poly for s in (idx.start, idx.stop, idx.step, size)):
start, stop, step = _polymorphic_slice_indices(idx, size)
elif isinstance(size, int):
start, stop, step = idx.indices(size)
else:
raise TypeError(size)
if type(start) is not Poly and type(stop) is not Poly:
if (step < 0 and stop >= start) or (step > 0 and start >= stop):
return 0, 0, 1, False # sliced to size zero
if step > 0:
return start, stop, step, False
else:
k = (start - stop - 1) % (-step)
return stop + k + 1, start + 1, -step, True
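# A minimal sketch: `_example_static_idx_negative_stride` is a hypothetical
# helper. For a[5:1:-2] on an axis of size 10 the gather reads positions 3 and 5
# (start=3, limit=6, stride=2) and the result is then reversed, so the expected
# return value is (3, 6, 2, True).
def _example_static_idx_negative_stride():
  return _static_idx(slice(5, 1, -2), 10)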
blackman = _wrap_numpy_nullary_function(np.blackman)
bartlett = _wrap_numpy_nullary_function(np.bartlett)
hamming = _wrap_numpy_nullary_function(np.hamming)
hanning = _wrap_numpy_nullary_function(np.hanning)
# TODO: lower `kaiser` via lax to allow non-constant beta values.
kaiser = _wrap_numpy_nullary_function(np.kaiser)
def _gcd_cond_fn(xs):
x1, x2 = xs
return any(x2 != 0)
def _gcd_body_fn(xs):
x1, x2 = xs
x1, x2 = (where(x2 != 0, x2, x1),
where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0)))
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
@_wraps(getattr(np, "gcd", None))
def gcd(x1, x2):
if (not issubdtype(_dtype(x1), integer) or
not issubdtype(_dtype(x2), integer)):
raise ValueError("Arguments to jax.numpy.gcd must be integers.")
x1, x2 = _promote_dtypes(x1, x2)
x1, x2 = broadcast_arrays(x1, x2)
gcd, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (abs(x1), abs(x2)))
return gcd
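# A minimal sketch: `_example_gcd` is a hypothetical helper exercising the
# elementwise Euclidean algorithm above via lax.while_loop; the expected result
# is array([4, 6]).
def _example_gcd():
  return gcd(array([12, 18]), array([20, 12]))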
@_wraps(getattr(np, "lcm", None))
def lcm(x1, x2):
x1, x2 = _promote_dtypes(x1, x2)
d = gcd(x1, x2)
return where(d == 0, lax._const(d, 0),
abs(multiply(x1, floor_divide(x2, d))))
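# A minimal sketch: `_example_lcm` is a hypothetical helper; lcm is derived from
# gcd as |x1 * (x2 // gcd(x1, x2))|, so lcm(4, 6) should be 12.
def _example_lcm():
  return lcm(array(4), array(6))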
@_wraps(np.extract)
def extract(condition, arr):
return compress(ravel(condition), ravel(arr))
@_wraps(np.compress)
def compress(condition, a, axis=None, out=None):
if out is not None:
raise NotImplementedError("out argument is not supported.")
if ndim(condition) != 1:
raise ValueError("condition must be a 1D array")
condition = array(condition).astype(bool)
a = array(a)
if axis is None:
axis = 0
a = ravel(a)
else:
a = moveaxis(a, axis, 0)
condition, extra = condition[:a.shape[0]], condition[a.shape[0]:]
if any(extra):
raise ValueError("condition contains entries that are out of bounds")
a = a[:condition.shape[0]]
return moveaxis(a[condition], 0, axis)
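# A minimal sketch: `_example_compress_rows` is a hypothetical helper; applying
# the boolean condition along axis 0 keeps rows 0 and 2 of the array, i.e.
# [[0, 1], [4, 5]].
def _example_compress_rows():
  a = arange(6).reshape(3, 2)
  return compress(array([True, False, True]), a, axis=0)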
@_wraps(np.cov)
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
aweights=None):
msg = ("jax.numpy.cov not implemented for nontrivial {}. "
"Open a feature request at https://github.com/google/jax/issues !")
if y is not None: raise NotImplementedError(msg.format('y'))
# These next two are actually implemented, just not tested.
if fweights is not None: raise NotImplementedError(msg.format('fweights'))
if aweights is not None: raise NotImplementedError(msg.format('aweights'))
if m.ndim > 2:
raise ValueError("m has more than 2 dimensions") # same as numpy error
X = array(m, ndmin=2, dtype=dtypes.canonicalize_dtype(result_type(m, float_)))
if not rowvar and X.shape[0] != 1:
X = X.T
if X.shape[0] == 0:
return array([]).reshape(0, 0)
if ddof is None:
ddof = 1 if bias == 0 else 0
w = None
if fweights is not None:
if np.ndim(fweights) > 1:
raise RuntimeError("cannot handle multidimensional fweights")
if np.shape(fweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and fweights")
w = asarray(fweights)
if aweights is not None:
if np.ndim(aweights) > 1:
raise RuntimeError("cannot handle multidimensional aweights")
if np.shape(aweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and aweights")
w = aweights if w is None else w * aweights
avg, w_sum = average(X, axis=1, weights=w, returned=True)
w_sum = w_sum[0]
if w is None:
f = X.shape[1] - ddof
elif ddof == 0:
f = w_sum
elif aweights is None:
f = w_sum - ddof
else:
f = w_sum - ddof * sum(w * aweights) / w_sum
X = X - avg[:, None]
X_T = X.T if w is None else (X * w).T
return true_divide(dot(X, X_T.conj()), f).squeeze()
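# A minimal sketch: `_example_cov` is a hypothetical helper; for two perfectly
# anti-correlated rows the expected (approximate) result is
# [[1., -1.], [-1., 1.]].
def _example_cov():
  return cov(array([[0., 1., 2.], [2., 1., 0.]]))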
@_wraps(np.corrcoef)
def corrcoef(x, y=None, rowvar=True):
c = cov(x, y, rowvar)
if len(shape(c)) == 0:
# scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
return divide(c, c)
d = diag(c)
stddev = sqrt(real(d))
c = divide(c, stddev[:,None])
c = divide(c, stddev[None,:])
real_part = clip(real(c), -1, 1)
if iscomplexobj(c):
complex_part = clip(imag(c), -1, 1)
c = lax.complex(real_part, complex_part)
else:
c = real_part
return c
@_wraps(getattr(np, "quantile", None))
def quantile(a, q, axis=None, out=None, overwrite_input=False,
interpolation="linear", keepdims=False):
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
return _quantile(a, q, axis, interpolation, keepdims, False)
@_wraps(getattr(np, "nanquantile", None))
def nanquantile(a, q, axis=None, out=None, overwrite_input=False,
interpolation="linear", keepdims=False):
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
return _quantile(a, q, axis, interpolation, keepdims, True)
@partial(jit, static_argnums=(2, 3, 4, 5))
def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
a = asarray(a, dtype=promote_types(_dtype(a), float32))
q = asarray(q, dtype=promote_types(_dtype(q), float32))
if axis is None:
a = ravel(a)
axis = 0
elif isinstance(axis, tuple):
raise NotImplementedError("Tuple values for axis are not implemented")
else:
axis = _canonicalize_axis(axis, ndim(a))
q_shape = shape(q)
q_ndim = ndim(q)
if q_ndim > 1:
raise ValueError("q must be have rank <= 1, got shape {}".format(shape(q)))
a_shape = shape(a)
a = lax.sort(a, dimension=axis)
if squash_nans:
counts = sum(logical_not(isnan(a)), axis=axis, dtype=q.dtype,
keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
counts = lax.expand_dims(counts, tuple(range(q_ndim)))
q = lax.mul(q, lax.sub(counts, _constant_like(q, 1)))
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_constant_like(high_weight, 1), high_weight)
low = lax.max(_constant_like(low, 0), lax.min(low, counts - 1))
high = lax.max(_constant_like(high, 0), lax.min(high, counts - 1))
low = lax.convert_element_type(low, int64)
high = lax.convert_element_type(high, int64)
out_shape = q_shape + shape_after_reduction
index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim)
for dim in range(len(shape_after_reduction))]
if keepdims:
index[axis] = low
else:
index.insert(axis, low)
low_value = a[tuple(index)]
index[axis] = high
high_value = a[tuple(index)]
else:
n = a_shape[axis]
q = lax.mul(q, _constant_like(q, n - 1))
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_constant_like(high_weight, 1), high_weight)
low = lax.clamp(_constant_like(low, 0), low, _constant_like(low, n - 1))
high = lax.clamp(_constant_like(high, 0), high, _constant_like(high, n - 1))
low = lax.convert_element_type(low, int64)
high = lax.convert_element_type(high, int64)
slice_sizes = list(a_shape)
slice_sizes[axis] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(range(
q_ndim,
len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
collapsed_slice_dims=() if keepdims else (axis,),
start_index_map=(axis,))
low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
if q_ndim == 1:
low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
broadcast_dimensions=(0,))
high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
broadcast_dimensions=(0,))
if interpolation == "linear":
result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
lax.mul(high_value.astype(q.dtype), high_weight))
elif interpolation == "lower":
result = low_value
elif interpolation == "higher":
result = high_value
elif interpolation == "nearest":
pred = lax.le(high_weight, _constant_like(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif interpolation == "midpoint":
result = lax.mul(lax.add(low_value, high_value), _constant_like(low_value, 0.5))
else:
raise ValueError(f"interpolation={interpolation!r} not recognized")
return lax.convert_element_type(result, a.dtype)
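# A minimal sketch: `_example_quantile` is a hypothetical helper exercising the
# sort-and-gather implementation above through the public wrapper; with linear
# interpolation the expected result is approximately [1.75, 2.5, 3.25].
def _example_quantile():
  return quantile(array([1., 2., 3., 4.]), array([0.25, 0.5, 0.75]))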
@partial(jit, static_argnums=2)
@partial(vectorize, excluded={0, 2})
def _searchsorted(a, v, side):
if len(a) == 0:
return 0
op = operator.le if side == 'left' else operator.lt
def body_fun(i, state):
low, high = state
mid = (low + high) // 2
go_left = op(v, a[mid])
return (where(go_left, low, mid), where(go_left, mid, high))
n_levels = int(np.ceil(np.log2(len(a) + 1)))
return lax.fori_loop(0, n_levels, body_fun, (0, len(a)))[1]
@_wraps(np.searchsorted)
def searchsorted(a, v, side='left', sorter=None):
if side not in ['left', 'right']:
raise ValueError(f"{side!r} is an invalid value for keyword 'side'")
if sorter is not None:
raise NotImplementedError("sorter is not implemented")
a = asarray(a)
v = asarray(v)
if ndim(a) != 1:
raise ValueError("a should be 1-dimensional")
return _searchsorted(a, v, side)
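# A minimal sketch: `_example_searchsorted` is a hypothetical helper; the binary
# search above runs a fixed number of fori_loop steps (ceil(log2(len(a) + 1))),
# so it also works under jit. The expected result is [0, 2, 2, 4].
def _example_searchsorted():
  return searchsorted(array([1, 3, 5, 7]), array([0, 4, 5, 9]))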
@_wraps(np.digitize)
def digitize(x, bins, right=False):
if len(bins) == 0:
    return zeros(shape(x), dtype=dtypes.canonicalize_dtype(int_))
side = 'right' if not right else 'left'
return where(
bins[-1] >= bins[0],
searchsorted(bins, x, side=side),
len(bins) - searchsorted(bins[::-1], x, side=side)
)
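# A minimal sketch: `_example_digitize` is a hypothetical helper; `digitize` is
# implemented via `searchsorted`, and for the increasing bins below the expected
# result is [1, 3, 2].
def _example_digitize():
  bins = array([0.0, 1.0, 2.5, 4.0])
  return digitize(array([0.2, 3.0, 1.0]), bins)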
_PIECEWISE_DOC = """\
Unlike `np.piecewise`, :py:func:`jax.numpy.piecewise` requires functions in
`funclist` to be traceable by JAX, as it is implemented via :func:`jax.lax.switch`.
See the :func:`jax.lax.switch` documentation for more information.
"""
@_wraps(np.piecewise, lax_description=_PIECEWISE_DOC)
def piecewise(x, condlist, funclist, *args, **kw):
condlist = array(condlist, dtype=bool_)
nc, nf = len(condlist), len(funclist)
if nf == nc + 1:
funclist = funclist[-1:] + funclist[:-1]
elif nf == nc:
funclist = [0] + list(funclist)
else:
raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}")
indices = argmax(cumsum(vstack([zeros_like(condlist[:1]), condlist]), 0), 0)
dtype = _dtype(x)
def _call(f):
return lambda x: f(x, *args, **kw).astype(dtype)
def _const(v):
return lambda x: full_like(x, v)
funclist = [_call(f) if callable(f) else _const(f) for f in funclist]
return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)
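# A minimal sketch: `_example_piecewise` is a hypothetical helper; both branch
# functions are traceable lambdas, so the dispatch compiles into a single
# lax.switch. The expected result is [2., 0., 3.] (a piecewise absolute value).
def _example_piecewise():
  x = array([-2.0, 0.0, 3.0])
  return piecewise(x, [x < 0], [lambda v: -v, lambda v: v])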
@_wraps(np.percentile)
def percentile(a, q, axis=None, out=None, overwrite_input=False,
interpolation="linear", keepdims=False):
q = true_divide(asarray(q), float32(100.0))
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)
@_wraps(np.nanpercentile)
def nanpercentile(a, q, axis=None, out=None, overwrite_input=False,
interpolation="linear", keepdims=False):
q = true_divide(asarray(q), float32(100.0))
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)
@_wraps(np.median)
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, interpolation='midpoint')
@_wraps(np.nanmedian)
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
interpolation='midpoint')
def _astype(arr, dtype):
lax._check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(arr, dtype)
def _nbytes(arr):
return size(arr) * _dtype(arr).itemsize
def _view(arr, dtype=None, type=None):
if type is not None:
raise NotImplementedError("`type` argument of array.view()")
if dtype is None:
return arr
arr_dtype = _dtype(arr)
if arr_dtype == dtype:
return arr
# bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type.
# We work around this by casting bool to uint8.
if arr_dtype == bool_:
arr = arr.astype(uint8)
nbits_in = 8 * arr_dtype.itemsize
nbits_out = 8 * _dtype(dtype).itemsize
if nbits_in == nbits_out:
if dtype == bool_:
return lax.bitcast_convert_type(arr, uint8).astype(dtype)
return lax.bitcast_convert_type(arr, dtype)
if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")
byte_dtypes = {8: uint8, 16: uint16, 32: uint32, 64: uint64}
if nbits_in not in byte_dtypes:
raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
if nbits_out not in byte_dtypes:
raise NotImplementedError(f"arr.view(dtype) for dtype={dtype}")
dt_in = byte_dtypes[nbits_in]
dt_out = byte_dtypes[nbits_out]
arr_bytes = lax.bitcast_convert_type(arr, dt_in)
if nbits_in < nbits_out:
shifts = arange(0, nbits_out, nbits_in, dtype=dt_out)
arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out)
arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out)
else:
shifts = arange(0, nbits_in, nbits_out, dtype=dt_in)
arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out)
arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,))
if dtype == bool_:
return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype)
return lax.bitcast_convert_type(arr_bytes, dtype)
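# A minimal sketch: `_example_view_bytes` is a hypothetical helper; viewing a
# uint32 array as uint8 splits each element into four bytes, least-significant
# byte first given the shift-based implementation above, so 258 (0x102) should
# become [2, 1, 0, 0].
def _example_view_bytes():
  return array([258], dtype=uint32).view(uint8)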
### track unimplemented functions
_NOT_IMPLEMENTED_DESC = """
*** This function is not yet implemented by jax.numpy, and will raise NotImplementedError ***
"""
def _not_implemented(fun):
@_wraps(fun, update_doc=False, lax_description=_NOT_IMPLEMENTED_DESC)
def wrapped(*args, **kwargs):
msg = "Numpy function {} not yet implemented"
raise NotImplementedError(msg.format(fun))
return wrapped
### add method and operator overloads to arraylike classes
# We add operator overloads to DeviceArray and ShapedArray. These method and
# operator overloads mainly just forward calls to the corresponding lax_numpy
# functions, which can themselves handle instances from any of these classes.
_scalar_types = (int, float, complex, np.generic)
def _defer_to_unrecognized_arg(binary_op):
# Ensure that other array types have the chance to override arithmetic.
def deferring_binary_op(self, other):
if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
return NotImplemented
return binary_op(self, other)
return deferring_binary_op
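# A minimal sketch: `_example_operator_deferral` is a hypothetical helper showing
# why returning NotImplemented above matters. Python then falls back to the
# unrecognized operand's reflected method, so the local class below handles the
# addition via __radd__.
def _example_operator_deferral():
  class OtherArray:
    def __radd__(self, other):
      return "handled by OtherArray.__radd__"
  return array([1, 2]) + OtherArray()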
def _swap_args(f):
return lambda x, y: f(y, x)
def _unimplemented_setitem(self, i, x):
msg = ("'{}' object does not support item assignment. JAX arrays are "
"immutable; perhaps you want jax.ops.index_update or "
"jax.ops.index_add instead?")
raise TypeError(msg.format(type(self)))
def _operator_round(number, ndigits=None):
out = round(number, decimals=ndigits or 0)
  # When `ndigits` is None, the builtin round() returns an integer for float
  # inputs, so cast to an integer dtype to match that behavior.
return out.astype(int_) if ndigits is None else out
_operators = {
"getitem": _rewriting_take,
"setitem": _unimplemented_setitem,
"neg": negative,
"pos": positive,
"eq": _defer_to_unrecognized_arg(equal),
"ne": _defer_to_unrecognized_arg(not_equal),
"lt": _defer_to_unrecognized_arg(less),
"le": _defer_to_unrecognized_arg(less_equal),
"gt": _defer_to_unrecognized_arg(greater),
"ge": _defer_to_unrecognized_arg(greater_equal),
"abs": abs,
"add": _defer_to_unrecognized_arg(add),
"radd": _defer_to_unrecognized_arg(add),
"sub": _defer_to_unrecognized_arg(subtract),
"rsub": _defer_to_unrecognized_arg(_swap_args(subtract)),
"mul": _defer_to_unrecognized_arg(multiply),
"rmul": _defer_to_unrecognized_arg(multiply),
"div": _defer_to_unrecognized_arg(divide),
"rdiv": _defer_to_unrecognized_arg(_swap_args(divide)),
"truediv": _defer_to_unrecognized_arg(true_divide),
"rtruediv": _defer_to_unrecognized_arg(_swap_args(true_divide)),
"floordiv": _defer_to_unrecognized_arg(floor_divide),
"rfloordiv": _defer_to_unrecognized_arg(_swap_args(floor_divide)),
"divmod": _defer_to_unrecognized_arg(divmod),
"rdivmod": _defer_to_unrecognized_arg(_swap_args(divmod)),
"mod": _defer_to_unrecognized_arg(mod),
"rmod": _defer_to_unrecognized_arg(_swap_args(mod)),
"pow": _defer_to_unrecognized_arg(power),
"rpow": _defer_to_unrecognized_arg(_swap_args(power)),
"matmul": _defer_to_unrecognized_arg(matmul),
"rmatmul": _defer_to_unrecognized_arg(_swap_args(matmul)),
"and": _defer_to_unrecognized_arg(bitwise_and),
"rand": _defer_to_unrecognized_arg(bitwise_and),
"or": _defer_to_unrecognized_arg(bitwise_or),
"ror": _defer_to_unrecognized_arg(bitwise_or),
"xor": _defer_to_unrecognized_arg(bitwise_xor),
"rxor": _defer_to_unrecognized_arg(bitwise_xor),
"invert": bitwise_not,
"lshift": _defer_to_unrecognized_arg(left_shift),
"rshift": _defer_to_unrecognized_arg(right_shift),
"rlshift": _defer_to_unrecognized_arg(_swap_args(left_shift)),
"rrshift": _defer_to_unrecognized_arg(_swap_args(right_shift)),
"round": _operator_round,
}
# These numpy.ndarray methods are just refs to an equivalent numpy function
_nondiff_methods = ["all", "any", "argmax", "argmin", "argpartition", "argsort",
"nonzero", "searchsorted", "round"]
_diff_methods = ["clip", "conj", "conjugate", "cumprod", "cumsum",
"diagonal", "dot", "max", "mean", "min", "prod", "ptp",
"ravel", "repeat", "sort", "squeeze", "std", "sum",
"swapaxes", "take", "tile", "trace", "transpose", "var"]
# These methods are mentioned explicitly in _nondiff_methods, so we create
# _not_implemented implementations of them here rather than in __init__.py.
# TODO(phawkins): implement these.
argpartition = _not_implemented(np.argpartition)
_NOT_IMPLEMENTED = ['argpartition']
# Set up operator, method, and property forwarding on Tracer instances containing
# ShapedArray avals by following the forwarding conventions for Tracer.
# Forward operators using a single-underscore-prefix naming convention:
for operator_name, function in _operators.items():
setattr(ShapedArray, "_{}".format(operator_name), staticmethod(function))
# Forward methods and properties using core.aval_method and core.aval_property:
for method_name in _nondiff_methods + _diff_methods:
setattr(ShapedArray, method_name, core.aval_method(globals()[method_name]))
setattr(ShapedArray, "reshape", core.aval_method(_reshape_method))
setattr(ShapedArray, "flatten", core.aval_method(ravel))
setattr(ShapedArray, "T", core.aval_property(transpose))
setattr(ShapedArray, "real", core.aval_property(real))
setattr(ShapedArray, "imag", core.aval_property(imag))
setattr(ShapedArray, "astype", core.aval_method(_astype))
setattr(ShapedArray, "view", core.aval_method(_view))
setattr(ShapedArray, "nbytes", core.aval_property(_nbytes))
# Forward operators, methods, and properties on DeviceArray to lax_numpy
# functions (with no Tracers involved; this forwarding is direct)
for operator_name, function in _operators.items():
setattr(DeviceArray, "__{}__".format(operator_name), function)
for method_name in _nondiff_methods + _diff_methods:
setattr(DeviceArray, method_name, globals()[method_name])
setattr(DeviceArray, "reshape", _reshape_method)
setattr(DeviceArray, "flatten", ravel)
setattr(DeviceArray, "T", property(transpose))
setattr(DeviceArray, "real", property(real))
setattr(DeviceArray, "imag", property(imag))
setattr(DeviceArray, "astype", _astype)
setattr(DeviceArray, "view", _view)
setattr(DeviceArray, "nbytes", property(_nbytes))
# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)
def __array_module__(self, types):
if builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types):
return jax.numpy
else:
return NotImplemented
setattr(ShapedArray, "_array_module", staticmethod(__array_module__))
setattr(DeviceArray, "__array_module__", __array_module__)
# Extra methods that are handy
setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast))
setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim))
setattr(ShapedArray, "split", core.aval_method(split))
setattr(DeviceArray, "broadcast", lax.broadcast)
setattr(DeviceArray, "broadcast_in_dim", lax.broadcast_in_dim)
setattr(DeviceArray, "split", split)
def _compress_method(a, condition, axis=None, out=None):
return compress(condition, a, axis, out)
setattr(ShapedArray, "compress", _compress_method)
setattr(DeviceArray, "compress", _compress_method)
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr: DeviceArray,
start_indices: Tuple[Tuple[int, ...]],
limit_indices: Tuple[Tuple[int, ...]],
removed_dims: Tuple[Tuple[int, ...]]):
"""Extracts multiple slices from `arr`.
This is used to shard DeviceArray arguments to pmap. It's implemented as a
DeviceArray method here to avoid circular imports.
"""
results = []
for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
sliced = lax.slice(arr, starts, limits)
if removed:
      sliced = sliced.reshape(np.delete(sliced.shape, removed))
results.append(sliced)
return results
setattr(DeviceArray, "_multi_slice", _multi_slice)
# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
# Note: this docstring will appear as the docstring for the `at` property.
"""Indexable helper object to call indexed update functions.
The `at` property is syntactic sugar for calling the indexed update functions
defined in :mod:`jax.ops`, and acts as a pure equivalent of in-place
  modifications.
In particular:
- ``x = x.at[idx].set(y)`` is a pure equivalent of ``x[idx] = y``.
- ``x = x.at[idx].add(y)`` is a pure equivalent of ``x[idx] += y``.
- ``x = x.at[idx].mul(y)`` is a pure equivalent of ``x[idx] *= y``.
- ``x = x.at[idx].min(y)`` is a pure equivalent of
``x[idx] = minimum(x[idx], y)``.
- ``x = x.at[idx].max(y)`` is a pure equivalent of
``x[idx] = maximum(x[idx], y)``.
"""
__slots__ = ("array",)
def __init__(self, array):
self.array = array
def __getitem__(self, index):
return _IndexUpdateRef(self.array, index)
def __repr__(self):
return f"_IndexUpdateHelper({repr(self.array)})"
class _IndexUpdateRef:
"""Helper object to call indexed update functions for an (advanced) index.
This object references a source array and a specific indexer into that array.
Methods on this object return copies of the source array that have been
modified at the positions specified by the indexer.
"""
__slots__ = ("array", "index")
def __init__(self, array, index):
self.array = array
self.index = index
def __repr__(self):
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
def set(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] = y``.
``x.at[idx].set(y)`` is syntactic sugar for
``jax.ops.index_update(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
See :mod:`jax.ops` for details.
"""
return ops.index_update(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def add(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] += y``.
``x.at[idx].add(y)`` is syntactic sugar for
``jax.ops.index_add(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] += y``.
See :mod:`jax.ops` for details.
"""
return ops.index_add(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def mul(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] += y``.
``x.at[idx].mul(y)`` is syntactic sugar for
``jax.ops.index_mul(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] *= y``.
See :mod:`jax.ops` for details.
"""
return ops.index_mul(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def min(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
``x.at[idx].min(y)`` is syntactic sugar for
``jax.ops.index_min(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>`
``x[idx] = minimum(x[idx], y)``.
See :mod:`jax.ops` for details.
"""
return ops.index_min(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def max(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
``x.at[idx].max(y)`` is syntactic sugar for
``jax.ops.index_max(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>`
``x[idx] = maximum(x[idx], y)``.
See :mod:`jax.ops` for details.
"""
return ops.index_max(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
setattr(DeviceArray, "at", property(_IndexUpdateHelper))
setattr(ShapedArray, "at", core.aval_property(_IndexUpdateHelper))