mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Rollback of:
d09d7b8d1363eab1c14051eb2376e605366537f9 by Jake VanderPlas <jakevdp@google.com>: Factor-out pieces of lax_numpy.py PiperOrigin-RevId: 431833044
This commit is contained in:
parent
4755dc3fee
commit
3766dd2120
File diff suppressed because it is too large
Load Diff
@ -1,295 +0,0 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ndarray is defined as an virtual abstract base class.
|
||||
|
||||
import abc
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
from jax import core
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import device_array
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ArrayMeta(abc.ABCMeta):
|
||||
"""Metaclass for overriding ndarray isinstance checks."""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
# Allow tracer instances with avals that are instances of UnshapedArray.
|
||||
# We could instead just declare Tracer an instance of the ndarray type, but
|
||||
# there can be traced values that are not arrays. The main downside here is
|
||||
# that isinstance(x, ndarray) might return true but
|
||||
# issubclass(type(x), ndarray) might return false for an array tracer.
|
||||
try:
|
||||
return (hasattr(instance, "aval") and
|
||||
isinstance(instance.aval, core.UnshapedArray))
|
||||
except AttributeError:
|
||||
super().__instancecheck__(instance)
|
||||
|
||||
|
||||
class ndarray(metaclass=ArrayMeta):
|
||||
dtype: np.dtype
|
||||
ndim: int
|
||||
shape: Tuple[int, ...]
|
||||
size: int
|
||||
|
||||
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
|
||||
order=None):
|
||||
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
|
||||
" Use jax.numpy.array, or jax.numpy.zeros instead.")
|
||||
|
||||
@abc.abstractmethod
|
||||
def __getitem__(self, key, indices_are_sorted=False,
|
||||
unique_indices=False) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __setitem__(self, key, value) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __iter__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __reversed__(self) -> Any: ...
|
||||
|
||||
# Comparisons
|
||||
@abc.abstractmethod
|
||||
def __lt__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __le__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __eq__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __ne__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __gt__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __ge__(self, other) -> Any: ...
|
||||
|
||||
# Unary arithmetic
|
||||
|
||||
@abc.abstractmethod
|
||||
def __neg__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __pos__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __abs__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __invert__(self) -> Any: ...
|
||||
|
||||
# Binary arithmetic
|
||||
|
||||
@abc.abstractmethod
|
||||
def __add__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __sub__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __mul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __matmul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __truediv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __floordiv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __mod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __divmod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __pow__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __lshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __and__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __xor__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __or__(self, other) -> Any: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __radd__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rsub__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rmul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rmatmul__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rtruediv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rfloordiv__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rmod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rdivmod__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rpow__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rlshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rrshift__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rand__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __rxor__(self, other) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __ror__(self, other) -> Any: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __bool__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __complex__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __int__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __float__(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def __round__(self, ndigits=None) -> Any: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __index__(self) -> Any: ...
|
||||
|
||||
# np.ndarray methods:
|
||||
@abc.abstractmethod
|
||||
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def astype(self, dtype) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def choose(self, choices, out=None, mode='raise') -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def clip(self, a_min=None, a_max=None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def conj(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def conjugate(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def copy(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def dot(self, b, *, precision=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def flatten(self) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def imag(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def item(self, *args) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=False, *, where=None,) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def nbytes(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def nonzero(self, *, size=None, fill_value=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=False,) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def ravel(self, order='C') -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def real(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def repeat(self, repeats, axis: Optional[int] = None, *,
|
||||
total_repeat_length=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def reshape(self, *args, order='C') -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def round(self, decimals=0, out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def searchsorted(self, v, side='left', sorter=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def swapaxes(self, axis1: int, axis2: int) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def take(self, indices, axis: Optional[int] = None, out=None,
|
||||
mode=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def tobytes(self, order='C') -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def tolist(self) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
|
||||
out=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def transpose(self, *args) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
|
||||
@abc.abstractmethod
|
||||
def view(self, dtype=None, type=None) -> Any: ...
|
||||
|
||||
# Even though we don't always support the NumPy array protocol, e.g., for
|
||||
# tracer types, for type checking purposes we must declare support so we
|
||||
# implement the NumPy ArrayLike protocol.
|
||||
def __array__(self) -> Any: ...
|
||||
|
||||
# JAX extensions
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def at(self) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def aval(self) -> Any: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def weak_type(self) -> bool: ...
|
||||
|
||||
|
||||
ndarray.register(device_array.DeviceArray)
|
||||
for t in device_array.device_array_types:
|
||||
ndarray.register(t)
|
||||
ndarray.register(pxla._SDA_BASE_CLASS)
|
@ -1,653 +0,0 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pytype: skip-file
|
||||
"""
|
||||
Implements ufuncs for jax.numpy.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
from textwrap import dedent
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src.api import jit, custom_jvp
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.numpy.util import (
|
||||
_check_arraylike, _promote_args, _promote_args_inexact,
|
||||
_promote_shapes, _where, _wraps)
|
||||
from jax import core
|
||||
|
||||
|
||||
_INT_DTYPES = {
|
||||
16: np.int16,
|
||||
32: np.int32,
|
||||
64: np.int64,
|
||||
}
|
||||
|
||||
|
||||
def _constant_like(x, const):
|
||||
return np.array(const, dtype=dtypes.dtype(x))
|
||||
|
||||
|
||||
def _result_dtype(op, *args):
|
||||
"""Compute result dtype of applying op to arguments with given dtypes."""
|
||||
args = [np.ones((0,) * np.ndim(arg), dtypes.dtype(arg)) for arg in args]
|
||||
return dtypes.dtype(op(*args))
|
||||
|
||||
|
||||
def _replace_inf(x):
|
||||
return lax.select(isposinf(real(x)), lax._zeros(x), x)
|
||||
|
||||
|
||||
def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
|
||||
if promote_to_inexact:
|
||||
fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
|
||||
else:
|
||||
fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
|
||||
return _wraps(numpy_fn, lax_description=doc)(fn)
|
||||
else:
|
||||
return _wraps(numpy_fn)(fn)
|
||||
|
||||
|
||||
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
|
||||
if promote_to_inexact:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
|
||||
else:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
|
||||
return _wraps(numpy_fn, lax_description=doc)(fn)
|
||||
else:
|
||||
return _wraps(numpy_fn)(fn)
|
||||
|
||||
|
||||
def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
|
||||
def fn(x1, x2):
|
||||
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
|
||||
return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
|
||||
return _wraps(numpy_fn, lax_description=doc)(fn)
|
||||
else:
|
||||
return _wraps(numpy_fn)(fn)
|
||||
|
||||
|
||||
def _comparison_op(numpy_fn, lax_fn):
|
||||
# TODO(https://github.com/google/jax/issues/6713): decorate this function with
|
||||
# jit, after fixing a surprising interaction with remat(..., concrete=True).
|
||||
def fn(x1, x2):
|
||||
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
|
||||
# Comparison on complex types are defined as a lexicographic ordering on
|
||||
# the (real, imag) pair.
|
||||
if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
|
||||
rx = lax.real(x1)
|
||||
ry = lax.real(x2)
|
||||
return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)),
|
||||
lax_fn(rx, ry))
|
||||
return lax_fn(x1, x2)
|
||||
return _wraps(numpy_fn)(fn)
|
||||
|
||||
|
||||
def _logical_op(np_op, bitwise_op):
|
||||
@_wraps(np_op, update_doc=False)
|
||||
@partial(jit, inline=True)
|
||||
def op(*args):
|
||||
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
|
||||
args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_) else lax.ne(x, zero(x))
|
||||
for x in args)
|
||||
return bitwise_op(*_promote_args(np_op.__name__, *args))
|
||||
return op
|
||||
|
||||
|
||||
fabs = _one_to_one_unop(np.fabs, lax.abs, True)
|
||||
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
|
||||
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
|
||||
negative = _one_to_one_unop(np.negative, lax.neg)
|
||||
positive = _one_to_one_unop(np.positive, lambda x: x)
|
||||
floor = _one_to_one_unop(np.floor, lax.floor, True)
|
||||
ceil = _one_to_one_unop(np.ceil, lax.ceil, True)
|
||||
exp = _one_to_one_unop(np.exp, lax.exp, True)
|
||||
log = _one_to_one_unop(np.log, lax.log, True)
|
||||
expm1 = _one_to_one_unop(np.expm1, lax.expm1, True)
|
||||
log1p = _one_to_one_unop(np.log1p, lax.log1p, True)
|
||||
sin = _one_to_one_unop(np.sin, lax.sin, True)
|
||||
cos = _one_to_one_unop(np.cos, lax.cos, True)
|
||||
tan = _one_to_one_unop(np.tan, lax.tan, True)
|
||||
arcsin = _one_to_one_unop(np.arcsin, lax.asin, True)
|
||||
arccos = _one_to_one_unop(np.arccos, lax.acos, True)
|
||||
arctan = _one_to_one_unop(np.arctan, lax.atan, True)
|
||||
sinh = _one_to_one_unop(np.sinh, lax.sinh, True)
|
||||
cosh = _one_to_one_unop(np.cosh, lax.cosh, True)
|
||||
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True)
|
||||
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
|
||||
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
|
||||
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
|
||||
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True)
|
||||
|
||||
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
|
||||
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
|
||||
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
|
||||
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
|
||||
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left)
|
||||
equal = _one_to_one_binop(np.equal, lax.eq)
|
||||
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
|
||||
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
|
||||
subtract = _one_to_one_binop(np.subtract, lax.sub)
|
||||
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True)
|
||||
minimum = _one_to_one_binop(np.minimum, lax.min)
|
||||
maximum = _one_to_one_binop(np.maximum, lax.max)
|
||||
float_power = _one_to_one_binop(np.float_power, lax.pow, True)
|
||||
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True)
|
||||
|
||||
greater_equal = _comparison_op(np.greater_equal, lax.ge)
|
||||
greater = _comparison_op(np.greater, lax.gt)
|
||||
less_equal = _comparison_op(np.less_equal, lax.le)
|
||||
less = _comparison_op(np.less, lax.lt)
|
||||
|
||||
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
|
||||
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
|
||||
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
|
||||
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
|
||||
|
||||
|
||||
@_wraps(np.arccosh)
|
||||
@jit
|
||||
def arccosh(x):
|
||||
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
|
||||
# convention than np.arccosh.
|
||||
out = lax.acosh(*_promote_args_inexact("arccosh", x))
|
||||
if dtypes.issubdtype(out.dtype, np.complexfloating):
|
||||
out = _where(real(out) < 0, lax.neg(out), out)
|
||||
return out
|
||||
|
||||
|
||||
@_wraps(np.right_shift)
|
||||
@partial(jit, inline=True)
|
||||
def right_shift(x1, x2):
|
||||
x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
|
||||
lax_fn = lax.shift_right_logical if \
|
||||
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
|
||||
return lax_fn(x1, x2)
|
||||
|
||||
|
||||
@_wraps(np.absolute)
|
||||
@partial(jit, inline=True)
|
||||
def absolute(x):
|
||||
_check_arraylike('absolute', x)
|
||||
dt = dtypes.dtype(x)
|
||||
return x if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
|
||||
abs = _wraps(np.abs)(absolute)
|
||||
|
||||
|
||||
@_wraps(np.rint)
|
||||
@jit
|
||||
def rint(x):
|
||||
_check_arraylike('rint', x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.integer):
|
||||
return lax.convert_element_type(x, dtypes.float_)
|
||||
if dtypes.issubdtype(dtype, np.complexfloating):
|
||||
return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
|
||||
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
|
||||
|
||||
|
||||
@_wraps(np.sign)
|
||||
@jit
|
||||
def sign(x):
|
||||
_check_arraylike('sign', x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.complexfloating):
|
||||
re = lax.real(x)
|
||||
return lax.complex(
|
||||
lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0))
|
||||
return lax.sign(x)
|
||||
|
||||
|
||||
@_wraps(np.copysign)
|
||||
@jit
|
||||
def copysign(x1, x2):
|
||||
x1, x2 = _promote_args_inexact("copysign", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
|
||||
raise TypeError("copysign does not support complex-valued inputs")
|
||||
return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
|
||||
|
||||
|
||||
@_wraps(np.true_divide)
|
||||
@partial(jit, inline=True)
|
||||
def true_divide(x1, x2):
|
||||
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
|
||||
return lax.div(x1, x2)
|
||||
|
||||
divide = true_divide
|
||||
|
||||
|
||||
@_wraps(np.floor_divide)
|
||||
@jit
|
||||
def floor_divide(x1, x2):
|
||||
x1, x2 = _promote_args("floor_divide", x1, x2)
|
||||
dtype = dtypes.dtype(x1)
|
||||
if dtypes.issubdtype(dtype, np.integer):
|
||||
quotient = lax.div(x1, x2)
|
||||
select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
|
||||
# TODO(mattjj): investigate why subtracting a scalar was causing promotion
|
||||
return _where(select, quotient - 1, quotient)
|
||||
elif dtypes.issubdtype(dtype, np.complexfloating):
|
||||
x1r = lax.real(x1)
|
||||
x1i = lax.imag(x1)
|
||||
x2r = lax.real(x2)
|
||||
x2i = lax.imag(x2)
|
||||
which = lax.ge(lax.abs(x2r), lax.abs(x2i))
|
||||
rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
|
||||
rat2 = _where(which, lax.div(x2i, x2r), lax._const(x2i, 1))
|
||||
out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
|
||||
lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
|
||||
return lax.convert_element_type(out, dtype)
|
||||
else:
|
||||
return _float_divmod(x1, x2)[0]
|
||||
|
||||
|
||||
@_wraps(np.divmod)
|
||||
@jit
|
||||
def divmod(x1, x2):
|
||||
x1, x2 = _promote_args("divmod", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
|
||||
return floor_divide(x1, x2), remainder(x1, x2)
|
||||
else:
|
||||
return _float_divmod(x1, x2)
|
||||
|
||||
|
||||
def _float_divmod(x1, x2):
|
||||
# see float_divmod in floatobject.c of CPython
|
||||
mod = lax.rem(x1, x2)
|
||||
div = lax.div(lax.sub(x1, mod), x2)
|
||||
|
||||
ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod))
|
||||
mod = lax.select(ind, mod + x2, mod)
|
||||
div = lax.select(ind, div - _constant_like(div, 1), div)
|
||||
|
||||
return lax.round(div), mod
|
||||
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def _power(x1, x2):
|
||||
x1, x2 = _promote_args("power", x1, x2)
|
||||
dtype = dtypes.dtype(x1)
|
||||
if not dtypes.issubdtype(dtype, np.integer):
|
||||
return lax.pow(x1, x2)
|
||||
|
||||
# Integer power => use binary exponentiation.
|
||||
|
||||
# TODO(phawkins): add integer pow support to XLA.
|
||||
bits = 6 # Anything more would overflow for any x1 > 1
|
||||
zero = _constant_like(x2, 0)
|
||||
one = _constant_like(x2, 1)
|
||||
# Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
|
||||
acc = _where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
|
||||
for _ in range(bits):
|
||||
acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
|
||||
x1 = lax.mul(x1, x1)
|
||||
x2 = lax.shift_right_logical(x2, one)
|
||||
return acc
|
||||
|
||||
|
||||
@_wraps(np.power)
|
||||
def power(x1, x2):
|
||||
# Special case for concrete integer scalars: use binary exponentiation.
|
||||
# Using lax.pow may be imprecise for floating-point values; the goal of this
|
||||
# code path is to make sure we end up with a precise output for the common
|
||||
# pattern ``x ** 2`` or similar.
|
||||
if isinstance(core.get_aval(x2), core.ConcreteArray):
|
||||
try:
|
||||
x2 = operator.index(x2)
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
return lax.integer_pow(x1, x2)
|
||||
return _power(x1, x2)
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@_wraps(np.logaddexp)
|
||||
@jit
|
||||
def logaddexp(x1, x2):
|
||||
x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
|
||||
amax = lax.max(x1, x2)
|
||||
if dtypes.issubdtype(x1.dtype, np.floating):
|
||||
delta = lax.sub(x1, x2)
|
||||
return lax.select(lax._isnan(delta),
|
||||
lax.add(x1, x2), # NaNs or infinities of the same sign.
|
||||
lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
|
||||
else:
|
||||
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
|
||||
out = lax.add(amax, lax.log1p(lax.exp(delta)))
|
||||
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
|
||||
|
||||
|
||||
def _wrap_between(x, _a):
|
||||
"""Wraps `x` between `[-a, a]`."""
|
||||
a = _constant_like(x, _a)
|
||||
two_a = _constant_like(x, 2 * _a)
|
||||
zero = _constant_like(x, 0)
|
||||
rem = lax.rem(lax.add(x, a), two_a)
|
||||
rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
|
||||
return lax.sub(rem, a)
|
||||
|
||||
|
||||
@logaddexp.defjvp
|
||||
def _logaddexp_jvp(primals, tangents):
|
||||
x1, x2 = primals
|
||||
t1, t2 = tangents
|
||||
x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
|
||||
primal_out = logaddexp(x1, x2)
|
||||
tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
|
||||
lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
|
||||
return primal_out, tangent_out
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@_wraps(np.logaddexp2)
|
||||
@jit
|
||||
def logaddexp2(x1, x2):
|
||||
x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
|
||||
amax = lax.max(x1, x2)
|
||||
if dtypes.issubdtype(x1.dtype, np.floating):
|
||||
delta = lax.sub(x1, x2)
|
||||
return lax.select(lax._isnan(delta),
|
||||
lax.add(x1, x2), # NaNs or infinities of the same sign.
|
||||
lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
|
||||
_constant_like(x1, np.log(2)))))
|
||||
else:
|
||||
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
|
||||
out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2))))
|
||||
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
|
||||
|
||||
|
||||
@logaddexp2.defjvp
|
||||
def _logaddexp2_jvp(primals, tangents):
|
||||
x1, x2 = primals
|
||||
t1, t2 = tangents
|
||||
x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
|
||||
primal_out = logaddexp2(x1, x2)
|
||||
tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
|
||||
lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
|
||||
return primal_out, tangent_out
|
||||
|
||||
|
||||
@_wraps(np.log2)
|
||||
@partial(jit, inline=True)
|
||||
def log2(x):
|
||||
x, = _promote_args_inexact("log2", x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
|
||||
|
||||
|
||||
@_wraps(np.log10)
|
||||
@partial(jit, inline=True)
|
||||
def log10(x):
|
||||
x, = _promote_args_inexact("log10", x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
|
||||
|
||||
|
||||
@_wraps(np.exp2)
|
||||
@partial(jit, inline=True)
|
||||
def exp2(x):
|
||||
x, = _promote_args_inexact("exp2", x)
|
||||
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
|
||||
|
||||
|
||||
@_wraps(np.signbit)
|
||||
@jit
|
||||
def signbit(x):
|
||||
x, = _promote_args("signbit", x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.integer):
|
||||
return lax.lt(x, _constant_like(x, 0))
|
||||
elif dtypes.issubdtype(dtype, np.bool_):
|
||||
return lax.full_like(x, False, dtype=np.bool_)
|
||||
elif not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(
|
||||
"jax.numpy.signbit is not well defined for %s" % dtype)
|
||||
|
||||
# TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
|
||||
# F32.
|
||||
if dtype == dtypes.bfloat16:
|
||||
dtype = np.float32
|
||||
x = lax.convert_element_type(x, np.float32)
|
||||
|
||||
info = dtypes.finfo(dtype)
|
||||
if info.bits not in _INT_DTYPES:
|
||||
raise NotImplementedError(
|
||||
"jax.numpy.signbit only supports 16, 32, and 64-bit types.")
|
||||
int_type = _INT_DTYPES[info.bits]
|
||||
x = lax.bitcast_convert_type(x, int_type)
|
||||
return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_)
|
||||
|
||||
|
||||
def _normalize_float(x):
|
||||
info = dtypes.finfo(dtypes.dtype(x))
|
||||
cond = lax.abs(x) < info.tiny
|
||||
x1 = _where(cond, x * lax._const(x, 1 << info.nmant), x)
|
||||
x2 = _where(cond, lax.full_like(x, -info.nmant, dtype=np.int32), lax.full_like(x, 0, dtype=np.int32))
|
||||
int_type = _INT_DTYPES[info.bits]
|
||||
return lax.bitcast_convert_type(x1, int_type), x2
|
||||
|
||||
|
||||
@_wraps(np.ldexp)
|
||||
@jit
|
||||
def ldexp(x1, x2):
|
||||
_check_arraylike("ldexp", x1, x2)
|
||||
dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
|
||||
x1, x2 = _promote_shapes("ldexp", x1, x2)
|
||||
x1 = lax.convert_element_type(x1, dtype)
|
||||
|
||||
info = dtypes.finfo(dtype)
|
||||
mask = (1 << info.nexp) - 1
|
||||
bias = ((1 << info.nexp) - 1) >> 1
|
||||
|
||||
int_type = _INT_DTYPES[info.bits]
|
||||
|
||||
x, e = _normalize_float(x1)
|
||||
x2 += e + ((x >> info.nmant) & mask) - bias
|
||||
|
||||
# find underflow/overflow before denormalization
|
||||
underflow_cond = x2 < -(bias + info.nmant)
|
||||
overflow_cond = x2 > bias
|
||||
|
||||
m = lax.full_like(x, 1, dtype=dtype)
|
||||
|
||||
# denormals
|
||||
cond = x2 < -bias + 1
|
||||
x2 = _where(cond, x2 + info.nmant, x2)
|
||||
m = _where(cond, m / (1 << info.nmant), m)
|
||||
|
||||
x2 = lax.convert_element_type(x2, np.int32)
|
||||
x &= ~(mask << info.nmant)
|
||||
x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)
|
||||
|
||||
x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)
|
||||
|
||||
# underflow
|
||||
x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
|
||||
# overflow
|
||||
x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
|
||||
# ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
|
||||
return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
|
||||
|
||||
|
||||
@_wraps(np.frexp)
|
||||
@jit
|
||||
def frexp(x):
|
||||
_check_arraylike("frexp", x)
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
raise TypeError("frexp does not support complex-valued inputs")
|
||||
elif not dtypes.issubdtype(dtypes.dtype(x), np.floating):
|
||||
x = lax.convert_element_type(x, np.float_)
|
||||
|
||||
dtype = dtypes.dtype(x)
|
||||
info = dtypes.finfo(dtype)
|
||||
mask = (1 << info.nexp) - 1
|
||||
bias = ((1 << info.nexp) - 1) >> 1
|
||||
|
||||
x1, x2 = _normalize_float(x)
|
||||
x2 += ((x1 >> info.nmant) & mask) - bias + 1
|
||||
x1 &= ~(mask << info.nmant)
|
||||
x1 |= (bias - 1) << info.nmant
|
||||
x1 = lax.bitcast_convert_type(x1, dtype)
|
||||
|
||||
cond = isinf(x) | isnan(x) | (x == 0)
|
||||
x2 = _where(cond, lax._zeros(x2), x2)
|
||||
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
|
||||
|
||||
|
||||
@_wraps(np.remainder)
|
||||
@jit
|
||||
def remainder(x1, x2):
|
||||
x1, x2 = _promote_args("remainder", x1, x2)
|
||||
zero = _constant_like(x1, 0)
|
||||
trunc_mod = lax.rem(x1, x2)
|
||||
trunc_mod_not_zero = lax.ne(trunc_mod, zero)
|
||||
do_plus = lax.bitwise_and(
|
||||
lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
|
||||
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
|
||||
mod = _wraps(np.mod)(remainder)
|
||||
|
||||
|
||||
@_wraps(np.fmod)
|
||||
@jit
|
||||
def fmod(x1, x2):
|
||||
_check_arraylike("fmod", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
|
||||
x2 = _where(x2 == 0, lax._ones(x2), x2)
|
||||
return lax.rem(*_promote_args("fmod", x1, x2))
|
||||
|
||||
|
||||
@_wraps(np.square)
|
||||
@partial(jit, inline=True)
|
||||
def square(x):
|
||||
_check_arraylike("square", x)
|
||||
return lax.integer_pow(x, 2)
|
||||
|
||||
|
||||
@_wraps(np.deg2rad)
|
||||
@partial(jit, inline=True)
|
||||
def deg2rad(x):
|
||||
x, = _promote_args_inexact("deg2rad", x)
|
||||
return lax.mul(x, lax._const(x, np.pi / 180))
|
||||
|
||||
|
||||
@_wraps(np.rad2deg)
|
||||
@partial(jit, inline=True)
|
||||
def rad2deg(x):
|
||||
x, = _promote_args_inexact("rad2deg", x)
|
||||
return lax.mul(x, lax._const(x, 180 / np.pi))
|
||||
|
||||
|
||||
degrees = rad2deg
|
||||
radians = deg2rad
|
||||
|
||||
|
||||
@_wraps(np.conjugate)
|
||||
@partial(jit, inline=True)
|
||||
def conjugate(x):
|
||||
_check_arraylike("conjugate", x)
|
||||
return lax.conj(x) if np.iscomplexobj(x) else x
|
||||
conj = conjugate
|
||||
|
||||
|
||||
@_wraps(np.imag)
|
||||
@partial(jit, inline=True)
|
||||
def imag(val):
|
||||
_check_arraylike("imag", val)
|
||||
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
|
||||
|
||||
|
||||
@_wraps(np.real)
|
||||
@partial(jit, inline=True)
|
||||
def real(val):
|
||||
_check_arraylike("real", val)
|
||||
return lax.real(val) if np.iscomplexobj(val) else val
|
||||
|
||||
@_wraps(np.modf, skip_params=['out'])
|
||||
@jit
|
||||
def modf(x, out=None):
|
||||
_check_arraylike("modf", x)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
|
||||
whole = _where(lax.ge(x, lax._zero(x)), floor(x), ceil(x))
|
||||
return x - whole, whole
|
||||
|
||||
|
||||
@_wraps(np.isfinite)
|
||||
@jit
|
||||
def isfinite(x):
|
||||
_check_arraylike("isfinite", x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.floating):
|
||||
return lax.is_finite(x)
|
||||
elif dtypes.issubdtype(dtype, np.complexfloating):
|
||||
return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
|
||||
else:
|
||||
return lax.full_like(x, True, dtype=np.bool_)
|
||||
|
||||
|
||||
@_wraps(np.isinf)
|
||||
@jit
|
||||
def isinf(x):
|
||||
_check_arraylike("isinf", x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.floating):
|
||||
return lax.eq(lax.abs(x), _constant_like(x, np.inf))
|
||||
elif dtypes.issubdtype(dtype, np.complexfloating):
|
||||
re = lax.real(x)
|
||||
im = lax.imag(x)
|
||||
return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, np.inf)),
|
||||
lax.eq(lax.abs(im), _constant_like(im, np.inf)))
|
||||
else:
|
||||
return lax.full_like(x, False, dtype=np.bool_)
|
||||
|
||||
|
||||
def _isposneginf(infinity, x, out):
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.floating):
|
||||
return lax.eq(x, _constant_like(x, infinity))
|
||||
elif dtypes.issubdtype(dtype, np.complexfloating):
|
||||
raise ValueError("isposinf/isneginf are not well defined for complex types")
|
||||
else:
|
||||
return lax.full_like(x, False, dtype=np.bool_)
|
||||
|
||||
|
||||
isposinf = _wraps(np.isposinf, skip_params=['out'])(
|
||||
lambda x, out=None: _isposneginf(np.inf, x, out)
|
||||
)
|
||||
|
||||
|
||||
isneginf = _wraps(np.isneginf, skip_params=['out'])(
|
||||
lambda x, out=None: _isposneginf(-np.inf, x, out)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(np.isnan)
|
||||
@jit
|
||||
def isnan(x):
|
||||
_check_arraylike("isnan", x)
|
||||
return lax.ne(x, x)
|
@ -12,21 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Callable, NamedTuple, Optional, Dict, Sequence
|
||||
import warnings
|
||||
|
||||
from jax._src.config import config
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax import lax
|
||||
|
||||
import numpy as np
|
||||
|
||||
_parameter_break = re.compile("\n(?=[A-Za-z_])")
|
||||
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
|
||||
@ -188,159 +178,3 @@ def _wraps(fun: Optional[Callable], update_doc: bool = True, lax_description: st
|
||||
setattr(op, attr, value)
|
||||
return op
|
||||
return wrap
|
||||
|
||||
def _promote_dtypes(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
def _promote_dtypes_inexact(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_inexact = _to_inexact_dtype(to_dtype)
|
||||
weak_type = (weak_type and to_dtype == to_dtype_inexact)
|
||||
return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args]
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
return dtype if dtypes.issubdtype(dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_)
|
||||
|
||||
def _arraylike(x):
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
|
||||
hasattr(x, '__jax_array__') or dtypes.is_python_scalar(x) or np.isscalar(x))
|
||||
|
||||
def _check_arraylike(fun_name, *args):
|
||||
"""Check if all args fit JAX's definition of arraylike."""
|
||||
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
|
||||
if any(not _arraylike(arg) for arg in args):
|
||||
pos, arg = next((i, arg) for i, arg in enumerate(args)
|
||||
if not _arraylike(arg))
|
||||
msg = "{} requires ndarray or scalar arguments, got {} at position {}."
|
||||
raise TypeError(msg.format(fun_name, type(arg), pos))
|
||||
|
||||
def _check_no_float0s(fun_name, *args):
|
||||
"""Check if none of the args have dtype float0."""
|
||||
if any(dtypes.dtype(arg) is dtypes.float0 for arg in args):
|
||||
raise TypeError(
|
||||
f"Called {fun_name} with a float0 array. "
|
||||
"float0s do not support any operations by design because they "
|
||||
"are not compatible with non-trivial vector spaces. No implicit dtype "
|
||||
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
|
||||
"to cast a float0 array to a regular zeros array. \n"
|
||||
"If you didn't expect to get a float0 you might have accidentally "
|
||||
"taken a gradient with respect to an integer argument.")
|
||||
|
||||
def _promote_args(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes(*args))
|
||||
|
||||
def _promote_args_inexact(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion.
|
||||
|
||||
Promotes non-inexact types to an inexact type."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
|
||||
|
||||
def _promote_shapes(fun_name, *args):
|
||||
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
|
||||
return args # no need for rank promotion, so rely on lax promotion
|
||||
nonscalar_ranks = {len(shp) for shp in shapes if shp}
|
||||
if len(nonscalar_ranks) < 2:
|
||||
return args
|
||||
else:
|
||||
if config.jax_numpy_rank_promotion != "allow":
|
||||
_rank_promotion_warning_or_error(fun_name, shapes)
|
||||
if config.jax_dynamic_shapes:
|
||||
# With dynamic shapes we don't support singleton-dimension broadcasting;
|
||||
# we instead broadcast out to the full shape as a temporary workaround.
|
||||
res_shape = lax.broadcast_shapes(*shapes)
|
||||
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
|
||||
else:
|
||||
result_rank = len(lax.broadcast_shapes(*shapes))
|
||||
return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
|
||||
for arg, shp in zip(args, shapes)]
|
||||
|
||||
def _rank_promotion_warning_or_error(fun_name, shapes):
|
||||
if config.jax_numpy_rank_promotion == "warn":
|
||||
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
|
||||
"Set the jax_numpy_rank_promotion config option to 'allow' to "
|
||||
"disable this warning; for more information, see "
|
||||
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
|
||||
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
|
||||
elif config.jax_numpy_rank_promotion == "raise":
|
||||
msg = ("Operands could not be broadcast together for {} on shapes {} "
|
||||
"and with the config option jax_numpy_rank_promotion='raise'. "
|
||||
"For more information, see "
|
||||
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
|
||||
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
|
||||
|
||||
|
||||
def _broadcast_to(arr, shape):
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape)
|
||||
arr = arr if isinstance(arr, ndarray) else api.device_put(arr)
|
||||
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
shape = core.canonicalize_shape(shape) # check that shape is concrete
|
||||
arr_shape = np.shape(arr)
|
||||
if core.symbolic_equal_shape(arr_shape, shape):
|
||||
return arr
|
||||
else:
|
||||
nlead = len(shape) - len(arr_shape)
|
||||
shape_tail = shape[nlead:]
|
||||
compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
|
||||
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
|
||||
if nlead < 0 or not compatible:
|
||||
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
|
||||
raise ValueError(msg.format(arr_shape, shape))
|
||||
diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
|
||||
for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
|
||||
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
|
||||
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
|
||||
return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
|
||||
|
||||
|
||||
@partial(api.jit, inline=True)
|
||||
def _broadcast_arrays(*args):
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
|
||||
# TODO(mattjj): remove the array(arg) here
|
||||
return [arg if isinstance(arg, ndarray) or np.isscalar(arg)
|
||||
else api.device_put(arg) for arg in args]
|
||||
result_shape = lax.broadcast_shapes(*shapes)
|
||||
return [_broadcast_to(arg, result_shape) for arg in args]
|
||||
|
||||
|
||||
# The `jit` on `where` exists to avoid materializing constants in cases like
|
||||
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
|
||||
# materialize the broadcast forms of scalar arguments.
|
||||
@api.jit
|
||||
def _where(condition, x=None, y=None):
|
||||
if x is None or y is None:
|
||||
raise ValueError("Either both or neither of the x and y arguments should "
|
||||
"be provided to jax.numpy.where, got {} and {}."
|
||||
.format(x, y))
|
||||
if not dtypes.issubdtype(dtypes.dtype(condition), np.bool_):
|
||||
condition = lax.ne(condition, lax._const(condition, 0))
|
||||
x, y = _promote_dtypes(x, y)
|
||||
condition, x, y = _broadcast_arrays(condition, x, y)
|
||||
try: is_always_empty = core.is_empty_shape(np.shape(x))
|
||||
except: is_always_empty = False # can fail with dynamic shapes
|
||||
return lax.select(condition, x, y) if not is_always_empty else x
|
||||
|
Loading…
x
Reference in New Issue
Block a user