Rollback of:

d09d7b8d1363eab1c14051eb2376e605366537f9 by Jake VanderPlas <jakevdp@google.com>:

Factor-out pieces of lax_numpy.py

PiperOrigin-RevId: 431833044
This commit is contained in:
jax authors 2022-03-01 19:38:50 -08:00
parent 4755dc3fee
commit 3766dd2120
4 changed files with 1049 additions and 1148 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,295 +0,0 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ndarray is defined as a virtual abstract base class.
import abc
from typing import Any, Optional, Tuple, Union
from jax import core
from jax.interpreters import pxla
from jax._src import device_array
import numpy as np
class ArrayMeta(abc.ABCMeta):
  """Metaclass for overriding ndarray isinstance checks."""

  def __instancecheck__(self, instance):
    """Return whether ``instance`` should be treated as an instance of ``self``.

    An object qualifies when it exposes an ``aval`` attribute that is a
    ``core.UnshapedArray``. If reading ``aval`` raises ``AttributeError``, the
    decision is delegated to the regular ``abc.ABCMeta`` machinery (which
    honours classes added through ``register``).
    """
    # Allow tracer instances with avals that are instances of UnshapedArray.
    # We could instead just declare Tracer an instance of the ndarray type, but
    # there can be traced values that are not arrays. The main downside here is
    # that isinstance(x, ndarray) might return true but
    # issubclass(type(x), ndarray) might return false for an array tracer.
    try:
      return (hasattr(instance, "aval") and
              isinstance(instance.aval, core.UnshapedArray))
    except AttributeError:
      # The fallback result must be returned; without the `return` this branch
      # evaluated to None and isinstance() always reported False.
      return super().__instancecheck__(instance)
class ndarray(metaclass=ArrayMeta):
  """Abstract array type declaring the interface of JAX array values.

  The class cannot be instantiated directly (``__init__`` always raises
  ``TypeError``), and ``isinstance`` checks against it are customized by the
  ``ArrayMeta`` metaclass. Almost every member below is an abstract stub that
  only declares a signature; the bodies are ``...``.
  """
  # Annotation-only attribute declarations (no values are assigned here).
  dtype: np.dtype
  ndim: int
  shape: Tuple[int, ...]
  size: int

  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    # Direct construction is rejected; the message points to the factories.
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

  # Container protocol
  @abc.abstractmethod
  def __getitem__(self, key, indices_are_sorted=False,
                  unique_indices=False) -> Any: ...
  @abc.abstractmethod
  def __setitem__(self, key, value) -> Any: ...
  @abc.abstractmethod
  def __len__(self) -> Any: ...
  @abc.abstractmethod
  def __iter__(self) -> Any: ...
  @abc.abstractmethod
  def __reversed__(self) -> Any: ...

  # Comparisons
  @abc.abstractmethod
  def __lt__(self, other) -> Any: ...
  @abc.abstractmethod
  def __le__(self, other) -> Any: ...
  @abc.abstractmethod
  def __eq__(self, other) -> Any: ...
  @abc.abstractmethod
  def __ne__(self, other) -> Any: ...
  @abc.abstractmethod
  def __gt__(self, other) -> Any: ...
  @abc.abstractmethod
  def __ge__(self, other) -> Any: ...

  # Unary arithmetic
  @abc.abstractmethod
  def __neg__(self) -> Any: ...
  @abc.abstractmethod
  def __pos__(self) -> Any: ...
  @abc.abstractmethod
  def __abs__(self) -> Any: ...
  @abc.abstractmethod
  def __invert__(self) -> Any: ...

  # Binary arithmetic
  @abc.abstractmethod
  def __add__(self, other) -> Any: ...
  @abc.abstractmethod
  def __sub__(self, other) -> Any: ...
  @abc.abstractmethod
  def __mul__(self, other) -> Any: ...
  @abc.abstractmethod
  def __matmul__(self, other) -> Any: ...
  @abc.abstractmethod
  def __truediv__(self, other) -> Any: ...
  @abc.abstractmethod
  def __floordiv__(self, other) -> Any: ...
  @abc.abstractmethod
  def __mod__(self, other) -> Any: ...
  @abc.abstractmethod
  def __divmod__(self, other) -> Any: ...
  @abc.abstractmethod
  def __pow__(self, other) -> Any: ...
  @abc.abstractmethod
  def __lshift__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rshift__(self, other) -> Any: ...
  @abc.abstractmethod
  def __and__(self, other) -> Any: ...
  @abc.abstractmethod
  def __xor__(self, other) -> Any: ...
  @abc.abstractmethod
  def __or__(self, other) -> Any: ...

  # Reflected binary arithmetic
  @abc.abstractmethod
  def __radd__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rsub__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rmul__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rmatmul__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rtruediv__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rfloordiv__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rmod__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rdivmod__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rpow__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rlshift__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rrshift__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rand__(self, other) -> Any: ...
  @abc.abstractmethod
  def __rxor__(self, other) -> Any: ...
  @abc.abstractmethod
  def __ror__(self, other) -> Any: ...

  # Conversions to Python scalars
  @abc.abstractmethod
  def __bool__(self) -> Any: ...
  @abc.abstractmethod
  def __complex__(self) -> Any: ...
  @abc.abstractmethod
  def __int__(self) -> Any: ...
  @abc.abstractmethod
  def __float__(self) -> Any: ...
  @abc.abstractmethod
  def __round__(self, ndigits=None) -> Any: ...
  @abc.abstractmethod
  def __index__(self) -> Any: ...

  # np.ndarray methods:
  @abc.abstractmethod
  def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Any: ...
  @abc.abstractmethod
  def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Any: ...
  @abc.abstractmethod
  def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
  @abc.abstractmethod
  def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
  @abc.abstractmethod
  def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ...
  @abc.abstractmethod
  def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
  @abc.abstractmethod
  def astype(self, dtype) -> Any: ...
  @abc.abstractmethod
  def choose(self, choices, out=None, mode='raise') -> Any: ...
  @abc.abstractmethod
  def clip(self, a_min=None, a_max=None, out=None) -> Any: ...
  @abc.abstractmethod
  def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ...
  @abc.abstractmethod
  def conj(self) -> Any: ...
  @abc.abstractmethod
  def conjugate(self) -> Any: ...
  @abc.abstractmethod
  def copy(self) -> Any: ...
  @abc.abstractmethod
  def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              dtype=None, out=None) -> Any: ...
  @abc.abstractmethod
  def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             dtype=None, out=None) -> Any: ...
  @abc.abstractmethod
  def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ...
  @abc.abstractmethod
  def dot(self, b, *, precision=None) -> Any: ...
  @abc.abstractmethod
  def flatten(self) -> Any: ...
  @property
  @abc.abstractmethod
  def imag(self) -> Any: ...
  @abc.abstractmethod
  def item(self, *args) -> Any: ...
  @abc.abstractmethod
  def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Any: ...
  @abc.abstractmethod
  def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=False, *, where=None,) -> Any: ...
  @abc.abstractmethod
  def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Any: ...
  @property
  @abc.abstractmethod
  def nbytes(self) -> Any: ...
  @abc.abstractmethod
  def nonzero(self, *, size=None, fill_value=None) -> Any: ...
  @abc.abstractmethod
  def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=None, initial=None, where=None) -> Any: ...
  @abc.abstractmethod
  def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=False,) -> Any: ...
  @abc.abstractmethod
  def ravel(self, order='C') -> Any: ...
  @property
  @abc.abstractmethod
  def real(self) -> Any: ...
  @abc.abstractmethod
  def repeat(self, repeats, axis: Optional[int] = None, *,
             total_repeat_length=None) -> Any: ...
  @abc.abstractmethod
  def reshape(self, *args, order='C') -> Any: ...
  @abc.abstractmethod
  def round(self, decimals=0, out=None) -> Any: ...
  @abc.abstractmethod
  def searchsorted(self, v, side='left', sorter=None) -> Any: ...
  @abc.abstractmethod
  def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
  @abc.abstractmethod
  def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ...
  @abc.abstractmethod
  def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
  @abc.abstractmethod
  def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=None, initial=None, where=None) -> Any: ...
  @abc.abstractmethod
  def swapaxes(self, axis1: int, axis2: int) -> Any: ...
  @abc.abstractmethod
  def take(self, indices, axis: Optional[int] = None, out=None,
           mode=None) -> Any: ...
  @abc.abstractmethod
  def tobytes(self, order='C') -> Any: ...
  @abc.abstractmethod
  def tolist(self) -> Any: ...
  @abc.abstractmethod
  def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
            out=None) -> Any: ...
  @abc.abstractmethod
  def transpose(self, *args) -> Any: ...
  @abc.abstractmethod
  def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
  @abc.abstractmethod
  def view(self, dtype=None, type=None) -> Any: ...

  # Even though we don't always support the NumPy array protocol, e.g., for
  # tracer types, for type checking purposes we must declare support so we
  # implement the NumPy ArrayLike protocol.
  # Note: unlike its neighbours, this stub is deliberately not abstract.
  def __array__(self) -> Any: ...

  # JAX extensions
  @property
  @abc.abstractmethod
  def at(self) -> Any: ...
  @property
  @abc.abstractmethod
  def aval(self) -> Any: ...
  @property
  @abc.abstractmethod
  def weak_type(self) -> bool: ...
# Register the concrete array classes as virtual subclasses of ndarray so that
# issubclass() checks against ndarray succeed for them.
ndarray.register(device_array.DeviceArray)
for t in device_array.device_array_types:
  ndarray.register(t)
ndarray.register(pxla._SDA_BASE_CLASS)

View File

@ -1,653 +0,0 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""
Implements ufuncs for jax.numpy.
"""
from functools import partial
import operator
from textwrap import dedent
import numpy as np
from jax._src.api import jit, custom_jvp
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.numpy.util import (
_check_arraylike, _promote_args, _promote_args_inexact,
_promote_shapes, _where, _wraps)
from jax import core
# Signed NumPy integer type for each supported bit width (16, 32 and 64);
# used to reinterpret the bits of a float of the same width.
_INT_DTYPES = {bits: getattr(np, f"int{bits}") for bits in (16, 32, 64)}
def _constant_like(x, const):
  """Return ``const`` as a NumPy array whose dtype is ``dtypes.dtype(x)``."""
  target_dtype = dtypes.dtype(x)
  return np.array(const, dtype=target_dtype)
def _result_dtype(op, *args):
  """Compute result dtype of applying op to arguments with given dtypes.

  Each argument is replaced by an array of ones with the same rank and dtype
  but zero-length axes, so ``op`` is applied without touching any data.
  """
  placeholders = []
  for arg in args:
    empty_shape = (0,) * np.ndim(arg)
    placeholders.append(np.ones(empty_shape, dtypes.dtype(arg)))
  return dtypes.dtype(op(*placeholders))
def _replace_inf(x):
  """Select ``lax._zeros(x)`` wherever the real part of ``x`` is +inf, else ``x``."""
  real_is_posinf = isposinf(real(x))
  return lax.select(real_is_posinf, lax._zeros(x), x)
def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
  """Build a jitted unary function that forwards to ``lax_fn``.

  Args:
    numpy_fn: the NumPy function being mirrored; supplies the name used during
      argument promotion and is passed to ``_wraps``.
    lax_fn: the callable applied to the promoted argument.
    promote_to_inexact: if true, promote with ``_promote_args_inexact`` instead
      of ``_promote_args``.
    lax_doc: if true, pass everything after the first paragraph of
      ``lax_fn.__doc__`` to ``_wraps`` as ``lax_description``.

  Returns:
    The wrapped function.
  """
  promote = _promote_args_inexact if promote_to_inexact else _promote_args
  fn = jit(lambda x: lax_fn(*promote(numpy_fn.__name__, x)), inline=True)
  if not lax_doc:
    return _wraps(numpy_fn)(fn)
  # Drop the first paragraph (the summary) of the lax docstring.
  paragraphs = lax_fn.__doc__.split('\n\n')
  doc = dedent('\n\n'.join(paragraphs[1:])).strip()
  return _wraps(numpy_fn, lax_description=doc)(fn)
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
  """Build a jitted binary function that forwards to ``lax_fn``.

  Args:
    numpy_fn: the NumPy function being mirrored; supplies the name used during
      argument promotion and is passed to ``_wraps``.
    lax_fn: the callable applied to the two promoted arguments.
    promote_to_inexact: if true, promote with ``_promote_args_inexact`` instead
      of ``_promote_args``.
    lax_doc: if true, pass everything after the first paragraph of
      ``lax_fn.__doc__`` to ``_wraps`` as ``lax_description``.

  Returns:
    The wrapped function.
  """
  promote = _promote_args_inexact if promote_to_inexact else _promote_args
  fn = jit(lambda x1, x2: lax_fn(*promote(numpy_fn.__name__, x1, x2)),
           inline=True)
  if not lax_doc:
    return _wraps(numpy_fn)(fn)
  # Drop the first paragraph (the summary) of the lax docstring.
  paragraphs = lax_fn.__doc__.split('\n\n')
  doc = dedent('\n\n'.join(paragraphs[1:])).strip()
  return _wraps(numpy_fn, lax_description=doc)(fn)
def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
  """Build a jitted binary function with a separate implementation for booleans.

  After promotion, ``bool_lax_fn`` is used when the first argument has boolean
  dtype and ``lax_fn`` otherwise. With ``lax_doc`` set, everything after the
  first paragraph of ``lax_fn.__doc__`` is passed to ``_wraps`` as
  ``lax_description``.
  """
  def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    if x1.dtype == np.bool_:
      return bool_lax_fn(x1, x2)
    return lax_fn(x1, x2)

  fn = jit(fn, inline=True)
  if not lax_doc:
    return _wraps(numpy_fn)(fn)
  paragraphs = lax_fn.__doc__.split('\n\n')
  doc = dedent('\n\n'.join(paragraphs[1:])).strip()
  return _wraps(numpy_fn, lax_description=doc)(fn)
def _comparison_op(numpy_fn, lax_fn):
  """Build an ordering comparison that also accepts complex arguments."""
  # TODO(https://github.com/google/jax/issues/6713): decorate this function with
  # jit, after fixing a surprising interaction with remat(..., concrete=True).
  def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    if not dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
      return lax_fn(x1, x2)
    # Complex values are ordered lexicographically on the (real, imag) pair:
    # the imaginary parts only decide when the real parts are equal.
    real1 = lax.real(x1)
    real2 = lax.real(x2)
    reals_equal = lax.eq(real1, real2)
    by_imag = lax_fn(lax.imag(x1), lax.imag(x2))
    by_real = lax_fn(real1, real2)
    return lax.select(reals_equal, by_imag, by_real)
  return _wraps(numpy_fn)(fn)
def _logical_op(np_op, bitwise_op):
  """Build a logical operator from a bitwise one applied to boolean inputs."""
  @_wraps(np_op, update_doc=False)
  @partial(jit, inline=True)
  def op(*args):
    def as_bool(x):
      # Boolean inputs pass through; anything else becomes `x != 0`.
      if dtypes.issubdtype(dtypes.dtype(x), np.bool_):
        return x
      return lax.ne(x, lax.full_like(x, shape=(), fill_value=0))

    bool_args = [as_bool(arg) for arg in args]
    return bitwise_op(*_promote_args(np_op.__name__, *bool_args))
  return op
# Unary ufuncs that forward directly to a lax function. A trailing `True` is
# the `promote_to_inexact` flag of _one_to_one_unop.
fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: x)
floor = _one_to_one_unop(np.floor, lax.floor, True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, True)
exp = _one_to_one_unop(np.exp, lax.exp, True)
log = _one_to_one_unop(np.log, lax.log, True)
expm1 = _one_to_one_unop(np.expm1, lax.expm1, True)
log1p = _one_to_one_unop(np.log1p, lax.log1p, True)
sin = _one_to_one_unop(np.sin, lax.sin, True)
cos = _one_to_one_unop(np.cos, lax.cos, True)
tan = _one_to_one_unop(np.tan, lax.tan, True)
arcsin = _one_to_one_unop(np.arcsin, lax.asin, True)
arccos = _one_to_one_unop(np.arccos, lax.acos, True)
arctan = _one_to_one_unop(np.arctan, lax.atan, True)
sinh = _one_to_one_unop(np.sinh, lax.sinh, True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, True)
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True)
# Binary ufuncs. `add` and `multiply` use a bitwise function for boolean
# inputs (see _maybe_bool_binop); the rest forward directly to lax.
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left)
equal = _one_to_one_binop(np.equal, lax.eq)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
subtract = _one_to_one_binop(np.subtract, lax.sub)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True)
minimum = _one_to_one_binop(np.minimum, lax.min)
maximum = _one_to_one_binop(np.maximum, lax.max)
float_power = _one_to_one_binop(np.float_power, lax.pow, True)
# nextafter additionally sets `lax_doc`, reusing part of the lax docstring.
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True)
# Ordering comparisons (with complex support, see _comparison_op).
greater_equal = _comparison_op(np.greater_equal, lax.ge)
greater = _comparison_op(np.greater, lax.gt)
less_equal = _comparison_op(np.less_equal, lax.le)
less = _comparison_op(np.less, lax.lt)
# Logical operators, implemented as bitwise operators on boolean inputs.
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
@_wraps(np.arccosh)
@jit
def arccosh(x):
  # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
  # convention than np.arccosh.
  x, = _promote_args_inexact("arccosh", x)
  result = lax.acosh(x)
  if not dtypes.issubdtype(result.dtype, np.complexfloating):
    return result
  # Flip the sign of complex results whose real part is negative.
  return _where(real(result) < 0, lax.neg(result), result)
@_wraps(np.right_shift)
@partial(jit, inline=True)
def right_shift(x1, x2):
  x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
  # Unsigned inputs use a logical shift; all others use an arithmetic shift.
  if np.issubdtype(x1.dtype, np.unsignedinteger):
    return lax.shift_right_logical(x1, x2)
  return lax.shift_right_arithmetic(x1, x2)
@_wraps(np.absolute)
@partial(jit, inline=True)
def absolute(x):
  _check_arraylike('absolute', x)
  dt = dtypes.dtype(x)
  # Booleans and unsigned integers are returned unchanged.
  if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger):
    return x
  return lax.abs(x)
abs = _wraps(np.abs)(absolute)
@_wraps(np.rint)
@jit
def rint(x):
  _check_arraylike('rint', x)
  in_dtype = dtypes.dtype(x)
  if dtypes.issubdtype(in_dtype, np.integer):
    # Integer inputs are only converted to the default float type.
    result = lax.convert_element_type(x, dtypes.float_)
  elif dtypes.issubdtype(in_dtype, np.complexfloating):
    # Round the real and imaginary parts independently.
    result = lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
  else:
    result = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
  return result
@_wraps(np.sign)
@jit
def sign(x):
  _check_arraylike('sign', x)
  if not dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
    return lax.sign(x)
  # Complex input: take the sign of the real part, or of the imaginary part
  # where the real part is zero; the result's imaginary part is zero.
  re = lax.real(x)
  deciding_part = _where(re != 0, re, lax.imag(x))
  return lax.complex(lax.sign(deciding_part), _constant_like(re, 0))
@_wraps(np.copysign)
@jit
def copysign(x1, x2):
  x1, x2 = _promote_args_inexact("copysign", x1, x2)
  if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
    raise TypeError("copysign does not support complex-valued inputs")
  # Choose -|x1| where x2 has its sign bit set and |x1| elsewhere.
  take_negative = signbit(x2).astype(bool)
  negative_magnitude = -lax.abs(x1)
  positive_magnitude = lax.abs(x1)
  return _where(take_negative, negative_magnitude, positive_magnitude)
@_wraps(np.true_divide)
@partial(jit, inline=True)
def true_divide(x1, x2):
  # Both operands are promoted to an inexact dtype before dividing.
  numerator, denominator = _promote_args_inexact("true_divide", x1, x2)
  return lax.div(numerator, denominator)
divide = true_divide
@_wraps(np.floor_divide)
@jit
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    # Subtract one from the lax.div quotient when the operands differ in sign
    # and the division is inexact.
    quotient = lax.div(x1, x2)
    needs_adjust = logical_and(lax.sign(x1) != lax.sign(x2),
                               lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_adjust, quotient - 1, quotient)

  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _float_divmod(x1, x2)[0]

  # Complex inputs: combine the real and imaginary parts using ratios chosen
  # from whichever part of x2 is larger in magnitude, then take the floor.
  num_re = lax.real(x1)
  num_im = lax.imag(x1)
  den_re = lax.real(x2)
  den_im = lax.imag(x2)
  re_dominates = lax.ge(lax.abs(den_re), lax.abs(den_im))
  rat1 = _where(re_dominates, lax.full_like(den_im, 1), lax.div(den_re, den_im))
  rat2 = _where(re_dominates, lax.div(den_im, den_re), lax._const(den_im, 1))
  numerator = lax.add(lax.mul(num_re, rat1), lax.mul(num_im, rat2))
  denominator = lax.add(lax.mul(den_re, rat1), lax.mul(den_im, rat2))
  out = lax.floor(lax.div(numerator, denominator))
  return lax.convert_element_type(out, dtype)
@_wraps(np.divmod)
@jit
def divmod(x1, x2):
  x1, x2 = _promote_args("divmod", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return _float_divmod(x1, x2)
  # Integer inputs: quotient and remainder are computed separately.
  return floor_divide(x1, x2), remainder(x1, x2)
def _float_divmod(x1, x2):
  """Return ``(quotient, remainder)`` of ``x1 / x2``.

  The remainder is adjusted so that, when non-zero, it has the sign of ``x2``.
  """
  # see float_divmod in floatobject.c of CPython
  trunc_mod = lax.rem(x1, x2)
  trunc_div = lax.div(lax.sub(x1, trunc_mod), x2)
  # Correct where the remainder is non-zero and its sign differs from x2's.
  needs_fix = lax.bitwise_and(trunc_mod != 0,
                              lax.sign(x2) != lax.sign(trunc_mod))
  floor_mod = lax.select(needs_fix, trunc_mod + x2, trunc_mod)
  floor_div = lax.select(needs_fix,
                         trunc_div - _constant_like(trunc_div, 1), trunc_div)
  return lax.round(floor_div), floor_mod
@partial(jit, inline=True)
def _power(x1, x2):
  """Elementwise ``x1 ** x2``; integer dtypes use binary exponentiation."""
  x1, x2 = _promote_args("power", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.
  # TODO(phawkins): add integer pow support to XLA.
  num_bits = 6  # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
  zero_base_nonzero_exp = lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero))
  acc = _where(zero_base_nonzero_exp, zero, one)
  for _ in range(num_bits):
    # Multiply in the current power of x1 when the low exponent bit is set,
    # then square the base and move on to the next exponent bit.
    acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
    x1 = lax.mul(x1, x1)
    x2 = lax.shift_right_logical(x2, one)
  return acc
@_wraps(np.power)
def power(x1, x2):
  # Special case for concrete integer scalars: use binary exponentiation.
  # Using lax.pow may be imprecise for floating-point values; the goal of this
  # code path is to make sure we end up with a precise output for the common
  # pattern ``x ** 2`` or similar.
  if isinstance(core.get_aval(x2), core.ConcreteArray):
    try:
      exponent = operator.index(x2)
    except TypeError:
      # x2 is concrete but not usable as an integer index; use the general path.
      exponent = None
    if exponent is not None:
      return lax.integer_pow(x1, exponent)
  return _power(x1, x2)
@custom_jvp
@_wraps(np.logaddexp)
@jit
def logaddexp(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    delta_is_nan = lax._isnan(delta)
    plain_sum = lax.add(x1, x2)  # NaNs or infinities of the same sign.
    # max(x1, x2) + log1p(exp(-|x1 - x2|))
    stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))
    return lax.select(delta_is_nan, plain_sum, stable)
  # Non-floating (complex) result: same formula with delta = x1 + x2 - 2 * amax,
  # and the imaginary part wrapped into [-pi, pi].
  delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
  out = lax.add(amax, lax.log1p(lax.exp(delta)))
  return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
def _wrap_between(x, _a):
  """Wraps `x` between `[-a, a]`."""
  half_width = _constant_like(x, _a)
  full_width = _constant_like(x, 2 * _a)
  zero = _constant_like(x, 0)
  # Shift by a, reduce modulo 2a, move negative remainders up, shift back.
  shifted = lax.rem(lax.add(x, half_width), full_width)
  shifted = lax.select(lax.lt(shifted, zero),
                       lax.add(shifted, full_width), shifted)
  return lax.sub(shifted, half_width)
@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
  """JVP rule for ``logaddexp``: each tangent is scaled by ``exp(x_i - out)``."""
  x1, x2 = primals
  t1, t2 = tangents
  x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
  primal_out = logaddexp(x1, x2)

  def weighted(tangent, primal):
    # _replace_inf is applied to both operands before subtracting.
    shifted = lax.sub(_replace_inf(primal), _replace_inf(primal_out))
    return lax.mul(tangent, exp(shifted))

  tangent_out = lax.add(weighted(t1, x1), weighted(t2, x2))
  return primal_out, tangent_out
@custom_jvp
@_wraps(np.logaddexp2)
@jit
def logaddexp2(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    delta_is_nan = lax._isnan(delta)
    plain_sum = lax.add(x1, x2)  # NaNs or infinities of the same sign.
    # max(x1, x2) + log1p(2 ** -|x1 - x2|) / log(2)
    correction = lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
                         _constant_like(x1, np.log(2)))
    return lax.select(delta_is_nan, plain_sum, lax.add(amax, correction))
  # Non-floating (complex) result: same formula with delta = x1 + x2 - 2 * amax,
  # and the imaginary part wrapped into [-pi / log(2), pi / log(2)].
  delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
  correction = lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2)))
  out = lax.add(amax, correction)
  return lax.complex(lax.real(out),
                     _wrap_between(lax.imag(out), np.pi / np.log(2)))
@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
  """JVP rule for ``logaddexp2``: each tangent is scaled by ``2 ** (x_i - out)``."""
  x1, x2 = primals
  t1, t2 = tangents
  x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
  primal_out = logaddexp2(x1, x2)

  def weighted(tangent, primal):
    # _replace_inf is applied to both operands before subtracting.
    shifted = lax.sub(_replace_inf(primal), _replace_inf(primal_out))
    return lax.mul(tangent, exp2(shifted))

  tangent_out = lax.add(weighted(t1, x1), weighted(t2, x2))
  return primal_out, tangent_out
@_wraps(np.log2)
@partial(jit, inline=True)
def log2(x):
  x, = _promote_args_inexact("log2", x)
  # Change of base: log(x) / log(2).
  natural_log = lax.log(x)
  log_of_two = lax.log(_constant_like(x, 2))
  return lax.div(natural_log, log_of_two)
@_wraps(np.log10)
@partial(jit, inline=True)
def log10(x):
  x, = _promote_args_inexact("log10", x)
  # Change of base: log(x) / log(10).
  natural_log = lax.log(x)
  log_of_ten = lax.log(_constant_like(x, 10))
  return lax.div(natural_log, log_of_ten)
@_wraps(np.exp2)
@partial(jit, inline=True)
def exp2(x):
  x, = _promote_args_inexact("exp2", x)
  # 2 ** x computed as exp(log(2) * x).
  log_of_two = lax.log(_constant_like(x, 2))
  return lax.exp(lax.mul(log_of_two, x))
@_wraps(np.signbit)
@jit
def signbit(x):
  x, = _promote_args("signbit", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.integer):
    return lax.lt(x, _constant_like(x, 0))
  if dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        "jax.numpy.signbit is not well defined for %s" % dtype)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.float32
    x = lax.convert_element_type(x, np.float32)

  float_info = dtypes.finfo(dtype)
  if float_info.bits not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  # Reinterpret the float as a same-width signed integer and shift the
  # exponent and mantissa bits away so only the sign bit decides the result.
  as_int = lax.bitcast_convert_type(x, _INT_DTYPES[float_info.bits])
  sign_shift = float_info.nexp + float_info.nmant
  return lax.convert_element_type(as_int >> sign_shift, np.bool_)
def _normalize_float(x):
  """Scale values below ``finfo.tiny`` up and return their bits as integers.

  Returns:
    A pair ``(bits, exponent_offset)``: the scaled value reinterpreted as a
    same-width signed integer, and an int32 array holding ``-nmant`` where
    scaling by ``2 ** nmant`` was applied and ``0`` elsewhere.
  """
  info = dtypes.finfo(dtypes.dtype(x))
  below_tiny = lax.abs(x) < info.tiny
  scaled = _where(below_tiny, x * lax._const(x, 1 << info.nmant), x)
  exponent_offset = _where(below_tiny,
                           lax.full_like(x, -info.nmant, dtype=np.int32),
                           lax.full_like(x, 0, dtype=np.int32))
  bits = lax.bitcast_convert_type(scaled, _INT_DTYPES[info.bits])
  return bits, exponent_offset
@_wraps(np.ldexp)
@jit
def ldexp(x1, x2):
  # Computes x1 * 2**x2 by rewriting the exponent field of x1's bit pattern.
  _check_arraylike("ldexp", x1, x2)
  dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
  x1, x2 = _promote_shapes("ldexp", x1, x2)
  x1 = lax.convert_element_type(x1, dtype)

  info = dtypes.finfo(dtype)
  # mask covers the exponent field (nexp bits); bias is half of it.
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  int_type = _INT_DTYPES[info.bits]

  # x holds the bits of the (possibly rescaled) input, e the matching
  # exponent correction; x2 becomes the exponent of the result.
  x, e = _normalize_float(x1)
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  # m is a multiplier applied at the end; it is 1 except for denormal results.
  m = lax.full_like(x, 1, dtype=dtype)

  # denormals
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  x2 = lax.convert_element_type(x2, np.int32)
  # Clear the exponent field, then write the new biased exponent into it.
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
@_wraps(np.frexp)
@jit
def frexp(x):
  # Decomposes x into a mantissa and an int32 power-of-two exponent by
  # reading and rewriting the exponent field of x's bit pattern.
  _check_arraylike("frexp", x)
  # dtypes.dtype(x) is used for both checks (the first previously read
  # x.dtype directly), matching the other functions in this module.
  in_dtype = dtypes.dtype(x)
  if dtypes.issubdtype(in_dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")
  elif not dtypes.issubdtype(in_dtype, np.floating):
    # dtypes.float_ (as in rint) replaces np.float_, which NumPy 2.0 removed.
    x = lax.convert_element_type(x, dtypes.float_)

  dtype = dtypes.dtype(x)
  info = dtypes.finfo(dtype)
  # mask covers the exponent field (nexp bits); bias is half of it.
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  # x1 holds the bits of the (possibly rescaled) input, x2 the matching
  # exponent correction.
  x1, x2 = _normalize_float(x)
  x2 += ((x1 >> info.nmant) & mask) - bias + 1
  # Replace the exponent field with bias - 1.
  x1 &= ~(mask << info.nmant)
  x1 |= (bias - 1) << info.nmant
  x1 = lax.bitcast_convert_type(x1, dtype)

  # inf, nan and 0 are returned unchanged with a zero exponent.
  cond = isinf(x) | isnan(x) | (x == 0)
  x2 = _where(cond, lax._zeros(x2), x2)
  return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
@_wraps(np.remainder)
@jit
def remainder(x1, x2):
  x1, x2 = _promote_args("remainder", x1, x2)
  zero = _constant_like(x1, 0)
  trunc_rem = lax.rem(x1, x2)
  rem_nonzero = lax.ne(trunc_rem, zero)
  # Add x2 when the remainder is non-zero and its sign differs from x2's.
  signs_differ = lax.ne(lax.lt(trunc_rem, zero), lax.lt(x2, zero))
  needs_shift = lax.bitwise_and(signs_differ, rem_nonzero)
  return lax.select(needs_shift, lax.add(trunc_rem, x2), trunc_rem)
mod = _wraps(np.mod)(remainder)
@_wraps(np.fmod)
@jit
def fmod(x1, x2):
  _check_arraylike("fmod", x1, x2)
  divisor = x2
  if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
    # For integer results, zero divisors are replaced with ones before lax.rem.
    divisor = _where(x2 == 0, lax._ones(x2), x2)
  lhs, rhs = _promote_args("fmod", x1, divisor)
  return lax.rem(lhs, rhs)
@_wraps(np.square)
@partial(jit, inline=True)
def square(x):
  _check_arraylike("square", x)
  # x ** 2 through the integer-power function.
  squared = lax.integer_pow(x, 2)
  return squared
@_wraps(np.deg2rad)
@partial(jit, inline=True)
def deg2rad(x):
  x, = _promote_args_inexact("deg2rad", x)
  radians_per_degree = lax._const(x, np.pi / 180)
  return lax.mul(x, radians_per_degree)
@_wraps(np.rad2deg)
@partial(jit, inline=True)
def rad2deg(x):
  x, = _promote_args_inexact("rad2deg", x)
  degrees_per_radian = lax._const(x, 180 / np.pi)
  return lax.mul(x, degrees_per_radian)
# Aliases for the two angle conversions.
degrees = rad2deg
radians = deg2rad
@_wraps(np.conjugate)
@partial(jit, inline=True)
def conjugate(x):
  _check_arraylike("conjugate", x)
  # Non-complex inputs are returned unchanged.
  if np.iscomplexobj(x):
    return lax.conj(x)
  return x
conj = conjugate
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val):
  _check_arraylike("imag", val)
  # Non-complex inputs have an all-zero imaginary part.
  if np.iscomplexobj(val):
    return lax.imag(val)
  return lax.full_like(val, 0)
@_wraps(np.real)
@partial(jit, inline=True)
def real(val):
  _check_arraylike("real", val)
  # Non-complex inputs are returned unchanged.
  if np.iscomplexobj(val):
    return lax.real(val)
  return val
@_wraps(np.modf, skip_params=['out'])
@jit
def modf(x, out=None):
  _check_arraylike("modf", x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
  # Round toward zero: floor for non-negative values, ceil for the rest.
  non_negative = lax.ge(x, lax._zero(x))
  integral_part = _where(non_negative, floor(x), ceil(x))
  fractional_part = x - integral_part
  return fractional_part, integral_part
@_wraps(np.isfinite)
@jit
def isfinite(x):
  _check_arraylike("isfinite", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.is_finite(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    # A complex value is finite only when both of its parts are.
    return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
  # Every other dtype is reported as finite everywhere.
  return lax.full_like(x, True, dtype=np.bool_)
@_wraps(np.isinf)
@jit
def isinf(x):
  _check_arraylike("isinf", x)

  def magnitude_is_inf(v):
    return lax.eq(lax.abs(v), _constant_like(v, np.inf))

  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return magnitude_is_inf(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    # A complex value is infinite when either of its parts is.
    re = lax.real(x)
    im = lax.imag(x)
    return lax.bitwise_or(magnitude_is_inf(re), magnitude_is_inf(im))
  # Every other dtype is reported as never infinite.
  return lax.full_like(x, False, dtype=np.bool_)
def _isposneginf(infinity, x, out):
  """Shared implementation of ``isposinf`` and ``isneginf``.

  Args:
    infinity: the value to compare against.
    x: the input to test.
    out: must be ``None``.

  Raises:
    NotImplementedError: if ``out`` is not ``None``.
    ValueError: if ``x`` has a complex dtype.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(x, _constant_like(x, infinity))
  # Non-floating inputs are reported as never infinite.
  return lax.full_like(x, False, dtype=np.bool_)
# isposinf / isneginf: thin wrappers that pass +inf / -inf to _isposneginf.
# The `out` parameter is accepted for signature compatibility only.
isposinf = _wraps(np.isposinf, skip_params=['out'])(
  lambda x, out=None: _isposneginf(np.inf, x, out)
)
isneginf = _wraps(np.isneginf, skip_params=['out'])(
  lambda x, out=None: _isposneginf(-np.inf, x, out)
)
@_wraps(np.isnan)
@jit
def isnan(x):
  _check_arraylike("isnan", x)
  # Implemented as a self-inequality test.
  differs_from_self = lax.ne(x, x)
  return differs_from_self

View File

@ -12,21 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import re
import textwrap
from typing import Callable, NamedTuple, Optional, Dict, Sequence
import warnings
from jax._src.config import config
from jax._src import dtypes
from jax._src.numpy.ndarray import ndarray
from jax._src.util import safe_zip
from jax._src import api
from jax import core
from jax import lax
import numpy as np
# Matches a newline that is immediately followed by a letter or underscore,
# i.e. the start of a line that begins with an identifier character.
_parameter_break = re.compile("\n(?=[A-Za-z_])")
# Matches a newline that is followed by a line of 3-15 characters and then a
# line starting with 3-15 dashes (a title underlined with dashes).
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
@ -188,159 +178,3 @@ def _wraps(fun: Optional[Callable], update_doc: bool = True, lax_description: st
setattr(op, attr, value)
return op
return wrap
def _promote_dtypes(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  With fewer than two arguments the inputs are returned unchanged (as the
  ``args`` tuple); otherwise a list of converted arguments is returned.
  """
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
    return args
  result_dtype, is_weak = dtypes._lattice_result_type(*args)
  result_dtype = dtypes.canonicalize_dtype(result_dtype)
  return [lax._convert_element_type(arg, result_dtype, is_weak)
          for arg in args]
def _promote_dtypes_inexact(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to an inexact type."""
  common_dtype, is_weak = dtypes._lattice_result_type(*args)
  common_dtype = dtypes.canonicalize_dtype(common_dtype)
  inexact_dtype = _to_inexact_dtype(common_dtype)
  # Weak typing is kept only when the common dtype was already inexact.
  is_weak = (is_weak and common_dtype == inexact_dtype)
  return [lax._convert_element_type(arg, inexact_dtype, is_weak)
          for arg in args]
def _to_inexact_dtype(dtype):
  """Promotes a dtype into an inexact dtype, if it is not already one."""
  if dtypes.issubdtype(dtype, np.inexact):
    return dtype
  return dtypes.promote_types(dtype, dtypes.float_)
def _arraylike(x):
  """Return whether ``x`` is accepted as array-like.

  Accepted: NumPy arrays, ``ndarray`` instances, objects with a
  ``__jax_array__`` attribute, Python scalars and NumPy scalars.
  """
  if isinstance(x, np.ndarray) or isinstance(x, ndarray):
    return True
  if hasattr(x, '__jax_array__'):
    return True
  return dtypes.is_python_scalar(x) or np.isscalar(x)
def _check_arraylike(fun_name, *args):
  """Check if all args fit JAX's definition of arraylike.

  Raises:
    TypeError: naming the type and position of the first offending argument.
  """
  assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
  for pos, arg in enumerate(args):
    if _arraylike(arg):
      continue
    msg = "{} requires ndarray or scalar arguments, got {} at position {}."
    raise TypeError(msg.format(fun_name, type(arg), pos))
def _check_no_float0s(fun_name, *args):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) is dtypes.float0 for arg in args):
raise TypeError(
f"Called {fun_name} with a float0 array. "
"float0s do not support any operations by design because they "
"are not compatible with non-trivial vector spaces. No implicit dtype "
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
"to cast a float0 array to a regular zeros array. \n"
"If you didn't expect to get a float0 you might have accidentally "
"taken a gradient with respect to an integer argument.")
def _promote_args(fun_name, *args):
  """Validate the arguments, then promote their dtypes and their shapes.

  Arguments must pass the arraylike and no-float0 checks; `fun_name` is
  forwarded to those checks and to shape promotion for error reporting.
  """
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  dtype_promoted = _promote_dtypes(*args)
  return _promote_shapes(fun_name, *dtype_promoted)
def _promote_args_inexact(fun_name, *args):
  """Validate the arguments, then promote them to an inexact dtype and to
  compatible shapes.

  Same pipeline as `_promote_args`, except that dtype promotion goes through
  `_promote_dtypes_inexact`.
  """
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  dtype_promoted = _promote_dtypes_inexact(*args)
  return _promote_shapes(fun_name, *dtype_promoted)
def _promote_shapes(fun_name, *args):
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return args
else:
shapes = [np.shape(arg) for arg in args]
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return args # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return args
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
if config.jax_dynamic_shapes:
# With dynamic shapes we don't support singleton-dimension broadcasting;
# we instead broadcast out to the full shape as a temporary workaround.
res_shape = lax.broadcast_shapes(*shapes)
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
result_rank = len(lax.broadcast_shapes(*shapes))
return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
for arg, shp in zip(args, shapes)]
def _rank_promotion_warning_or_error(fun_name, shapes):
if config.jax_numpy_rank_promotion == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
"disable this warning; for more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
elif config.jax_numpy_rank_promotion == "raise":
msg = ("Operands could not be broadcast together for {} on shapes {} "
"and with the config option jax_numpy_rank_promotion='raise'. "
"For more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
def _broadcast_to(arr, shape):
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape)
arr = arr if isinstance(arr, ndarray) else api.device_put(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
shape = (shape,)
shape = core.canonicalize_shape(shape) # check that shape is concrete
arr_shape = np.shape(arr)
if core.symbolic_equal_shape(arr_shape, shape):
return arr
else:
nlead = len(shape) - len(arr_shape)
shape_tail = shape[nlead:]
compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
if nlead < 0 or not compatible:
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
raise ValueError(msg.format(arr_shape, shape))
diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
@partial(api.jit, inline=True)
def _broadcast_arrays(*args):
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [np.shape(arg) for arg in args]
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
# TODO(mattjj): remove the array(arg) here
return [arg if isinstance(arg, ndarray) or np.isscalar(arg)
else api.device_put(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [_broadcast_to(arg, result_shape) for arg in args]
# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition, x=None, y=None):
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments should "
"be provided to jax.numpy.where, got {} and {}."
.format(x, y))
if not dtypes.issubdtype(dtypes.dtype(condition), np.bool_):
condition = lax.ne(condition, lax._const(condition, 0))
x, y = _promote_dtypes(x, y)
condition, x, y = _broadcast_arrays(condition, x, y)
try: is_always_empty = core.is_empty_shape(np.shape(x))
except: is_always_empty = False # can fail with dynamic shapes
return lax.select(condition, x, y) if not is_always_empty else x