mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
First pass at ufunc interfaces for several jax.numpy functions
This commit is contained in:
parent
db4be03f02
commit
a3d6cf007e
@ -27,6 +27,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
|
||||
and ``meta_fields`` includes all dataclass fields with ``init=True``
|
||||
and only them, if ``nodetype`` is a dataclass.
|
||||
* Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc`
|
||||
interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`,
|
||||
{obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`,
|
||||
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
|
||||
{obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`.
|
||||
In future releases we plan to expand these to other ufuncs.
|
||||
|
||||
* Breaking changes
|
||||
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
|
||||
|
@ -30,7 +30,6 @@ from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy.util import (
|
||||
_broadcast_to, check_arraylike, _complex_elem_type,
|
||||
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
|
||||
@ -2039,9 +2038,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
|
||||
a_shape = a.shape
|
||||
|
||||
if squash_nans:
|
||||
a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
|
||||
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
|
||||
a = lax.sort(a, dimension=axis)
|
||||
counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
|
||||
counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
|
||||
shape_after_reduction = counts.shape
|
||||
q = lax.expand_dims(
|
||||
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
|
||||
@ -2067,7 +2066,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
|
||||
index[axis] = high
|
||||
high_value = a[tuple(index)]
|
||||
else:
|
||||
a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
|
||||
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
|
||||
a = lax.sort(a, dimension=axis)
|
||||
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
|
||||
q = lax.mul(q, n - 1)
|
||||
@ -2223,7 +2222,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
|
||||
Array([1.5, 3. , 4.5], dtype=float32)
|
||||
"""
|
||||
check_arraylike("nanpercentile", a, q)
|
||||
q = ufuncs.true_divide(q, 100.0)
|
||||
q, = promote_dtypes_inexact(q)
|
||||
q = q / 100
|
||||
if not isinstance(interpolation, DeprecatedArg):
|
||||
deprecations.warn(
|
||||
"jax-numpy-quantile-interpolation",
|
||||
|
@ -25,13 +25,11 @@ from typing import Any
|
||||
import jax
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
|
||||
import jax._src.numpy.lax_numpy as jnp
|
||||
from jax._src.numpy.reductions import _moveaxis
|
||||
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where
|
||||
from jax._src.numpy.util import check_arraylike, _broadcast_to, _where
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.util import canonicalize_axis, set_module
|
||||
from jax._src import pjit
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -42,81 +40,126 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
|
||||
"""
|
||||
|
||||
|
||||
def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None:
|
||||
"""
|
||||
If fun(*args) lowers to a single primitive with inputs and outputs matching
|
||||
function inputs and outputs, return that primitive. Otherwise return None.
|
||||
"""
|
||||
try:
|
||||
jaxpr = jax.make_jaxpr(fun)(*args)
|
||||
except:
|
||||
return None
|
||||
while len(jaxpr.eqns) == 1:
|
||||
eqn = jaxpr.eqns[0]
|
||||
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
|
||||
return None
|
||||
elif (eqn.primitive == pjit.pjit_p and
|
||||
all(pjit.is_unspecified(sharding) for sharding in
|
||||
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
|
||||
jaxpr = jaxpr.eqns[0].params['jaxpr']
|
||||
else:
|
||||
return jaxpr.eqns[0].primitive
|
||||
return None
|
||||
|
||||
|
||||
_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = {
|
||||
lax_internal.add_p: reductions.sum,
|
||||
lax_internal.mul_p: reductions.prod,
|
||||
}
|
||||
|
||||
|
||||
_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = {
|
||||
lax_internal.add_p: reductions.cumsum,
|
||||
lax_internal.mul_p: reductions.cumprod,
|
||||
}
|
||||
|
||||
|
||||
@set_module('jax.numpy')
|
||||
class ufunc:
|
||||
"""Functions that operate element-by-element on whole arrays.
|
||||
"""Universal functions which operation element-by-element on arrays.
|
||||
|
||||
This is a class for LAX-backed implementations of numpy ufuncs.
|
||||
JAX implementation of :class:`numpy.ufunc`.
|
||||
|
||||
This is a class for JAX-backed implementations of NumPy's ufunc APIs.
|
||||
Most users will never need to instantiate :class:`ufunc`, but rather
|
||||
will use the pre-defined ufuncs in :mod:`jax.numpy`.
|
||||
|
||||
For constructing your own ufuncs, see :func:`jax.numpy.frompyfunc`.
|
||||
|
||||
Examples:
|
||||
Universal functions are functions that apply element-wise to broadcasted
|
||||
arrays, but they also come with a number of extra attributes and methods.
|
||||
|
||||
As an example, consider the function :obj:`jax.numpy.add`. The object
|
||||
acts as a function that applies addition to broadcasted arrays in an
|
||||
element-wise manner:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 4, 5])
|
||||
>>> jnp.add(x, 1)
|
||||
Array([2, 3, 4, 5, 6], dtype=int32)
|
||||
|
||||
Each :class:`ufunc` object includes a number of attributes that describe
|
||||
its behavior:
|
||||
|
||||
>>> jnp.add.nin # number of inputs
|
||||
2
|
||||
>>> jnp.add.nout # number of outputs
|
||||
1
|
||||
>>> jnp.add.identity # identity value, or None if no identity exists
|
||||
0
|
||||
|
||||
Binary ufuncs like :obj:`jax.numpy.add` include number of methods to
|
||||
apply the function to arrays in different manners.
|
||||
|
||||
The :meth:`~ufunc.outer` method applies the function to the
|
||||
pair-wise outer-product of the input array values:
|
||||
|
||||
>>> jnp.add.outer(x, x)
|
||||
Array([[ 2, 3, 4, 5, 6],
|
||||
[ 3, 4, 5, 6, 7],
|
||||
[ 4, 5, 6, 7, 8],
|
||||
[ 5, 6, 7, 8, 9],
|
||||
[ 6, 7, 8, 9, 10]], dtype=int32)
|
||||
|
||||
The :meth:`ufunc.reduce` method perfoms a reduction over the array.
|
||||
For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``:
|
||||
|
||||
>>> jnp.add.reduce(x)
|
||||
Array(15, dtype=int32)
|
||||
|
||||
The :meth:`ufunc.accumulate` method performs a cumulative reduction
|
||||
over the array. For example, :meth:`jnp.add.accumulate` is equivalent
|
||||
to :func:`jax.numpy.cumulative_sum`:
|
||||
|
||||
>>> jnp.add.accumulate(x)
|
||||
Array([ 1, 3, 6, 10, 15], dtype=int32)
|
||||
|
||||
The :meth:`ufunc.at` method applies the function at particular indices in the
|
||||
array; for ``jnp.add`` the computation is similar to :func:`jax.lax.scatter_add`:
|
||||
|
||||
>>> jnp.add.at(x, 0, 100, inplace=False)
|
||||
Array([101, 2, 3, 4, 5], dtype=int32)
|
||||
|
||||
And the :meth:`ufunc.reduceat` method performs a number of ``reduce``
|
||||
operations bewteen specified indices of an array; for ``jnp.add`` the
|
||||
operation is similar to :func:`jax.ops.segment_sum`:
|
||||
|
||||
>>> jnp.add.reduceat(x, jnp.array([0, 2]))
|
||||
Array([ 3, 12], dtype=int32)
|
||||
|
||||
In this case, the first element is ``x[0:2].sum()``, and the second element
|
||||
is ``x[2:].sum()``.
|
||||
"""
|
||||
def __init__(self, func: Callable[..., Any], /,
|
||||
nin: int, nout: int, *,
|
||||
name: str | None = None,
|
||||
nargs: int | None = None,
|
||||
identity: Any = None, update_doc=False):
|
||||
identity: Any = None,
|
||||
call: Callable[..., Any] | None = None,
|
||||
reduce: Callable[..., Any] | None = None,
|
||||
accumulate: Callable[..., Any] | None = None,
|
||||
at: Callable[..., Any] | None = None,
|
||||
reduceat: Callable[..., Any] | None = None,
|
||||
):
|
||||
self.__doc__ = func.__doc__
|
||||
self.__name__ = name or func.__name__
|
||||
# We want ufunc instances to work properly when marked as static,
|
||||
# and for this reason it's important that their properties not be
|
||||
# mutated. We prevent this by storing them in a dunder attribute,
|
||||
# and accessing them via read-only properties.
|
||||
if update_doc:
|
||||
self.__doc__ = func.__doc__
|
||||
self.__name__ = name or func.__name__
|
||||
self.__static_props = {
|
||||
'func': func,
|
||||
'call': vectorize(func),
|
||||
'nin': operator.index(nin),
|
||||
'nout': operator.index(nout),
|
||||
'nargs': operator.index(nargs or nin),
|
||||
'identity': identity
|
||||
'identity': identity,
|
||||
'call': call,
|
||||
'reduce': reduce,
|
||||
'accumulate': accumulate,
|
||||
'at': at,
|
||||
'reduceat': reduceat,
|
||||
}
|
||||
|
||||
_func = property(lambda self: self.__static_props['func'])
|
||||
_call = property(lambda self: self.__static_props['call'])
|
||||
nin = property(lambda self: self.__static_props['nin'])
|
||||
nout = property(lambda self: self.__static_props['nout'])
|
||||
nargs = property(lambda self: self.__static_props['nargs'])
|
||||
identity = property(lambda self: self.__static_props['identity'])
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Do not include _call, because it is computed from _func.
|
||||
# In both __hash__ and __eq__, we do not consider call, reduce, etc.
|
||||
# because they are considered implementation details rather than
|
||||
# necessary parts of object identity.
|
||||
return hash((self._func, self.__name__, self.identity,
|
||||
self.nin, self.nout, self.nargs))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
# Do not include _call, because it is computed from _func.
|
||||
return isinstance(other, ufunc) and (
|
||||
(self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) ==
|
||||
(other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs))
|
||||
@ -124,20 +167,71 @@ class ufunc:
|
||||
def __repr__(self) -> str:
|
||||
return f"<jnp.ufunc '{self.__name__}'>"
|
||||
|
||||
def __call__(self, *args: ArrayLike,
|
||||
out: None = None, where: None = None,
|
||||
**kwargs: Any) -> Any:
|
||||
def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any:
|
||||
check_arraylike(self.__name__, *args)
|
||||
if out is not None:
|
||||
raise NotImplementedError(f"out argument of {self}")
|
||||
if where is not None:
|
||||
raise NotImplementedError(f"where argument of {self}")
|
||||
return self._call(*args, **kwargs)
|
||||
call = self.__static_props['call'] or self._call_vectorized
|
||||
return call(*args)
|
||||
|
||||
@partial(jax.jit, static_argnames=['self'])
|
||||
def _call_vectorized(self, *args):
|
||||
return vectorize(self._func)(*args)
|
||||
|
||||
@implements(np.ufunc.reduce, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
|
||||
def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
|
||||
def reduce(self, a: ArrayLike, axis: int = 0,
|
||||
dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
"""Reduction operation derived from a binary function.
|
||||
|
||||
JAX implementation of :meth:`numpy.ufunc.reduce`.
|
||||
|
||||
Args:
|
||||
a: Input array.
|
||||
axis: integer specifying the axis over which to reduce. default=0
|
||||
dtype: optionally specify the type of the output array.
|
||||
out: Unused by JAX
|
||||
keepdims: If True, reduced axes are left in the result with size 1.
|
||||
If False (default) then reduced axes are squeezed out.
|
||||
initial: int or array, Default=None. Initial value for the reduction.
|
||||
where: boolean mask, default=None. The elements to be used in the sum. Array
|
||||
should be broadcast compatible to the input.
|
||||
|
||||
Returns:
|
||||
array containing the result of the reduction operation.
|
||||
|
||||
Examples:
|
||||
Consider the following array:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
|
||||
:meth:`jax.numpy.add.reduce` is equivalent to :func:`jax.numpy.sum`
|
||||
along ``axis=0``:
|
||||
|
||||
>>> jnp.add.reduce(x)
|
||||
Array([5, 7, 9], dtype=int32)
|
||||
>>> x.sum(0)
|
||||
Array([5, 7, 9], dtype=int32)
|
||||
|
||||
Similarly, :meth:`jax.numpy.logical_and.reduce` is equivalent to
|
||||
:func:`jax.numpy.all`:
|
||||
|
||||
>>> jnp.logical_and.reduce(x > 2)
|
||||
Array([False, False, True], dtype=bool)
|
||||
>>> jnp.all(x > 2, axis=0)
|
||||
Array([False, False, True], dtype=bool)
|
||||
|
||||
Some reductions do not correspond to any built-in aggregation function;
|
||||
for example here is the reduction of :func:`jax.numpy.bitwise_or` along
|
||||
the first axis of ``x``:
|
||||
|
||||
>>> jnp.bitwise_or.reduce(x, axis=1)
|
||||
Array([3, 7], dtype=int32)
|
||||
"""
|
||||
check_arraylike(f"{self.__name__}.reduce", a)
|
||||
if self.nin != 2:
|
||||
raise ValueError("reduce only supported for binary ufuncs")
|
||||
@ -154,14 +248,10 @@ class ufunc:
|
||||
"so to use a where mask one has to specify 'initial'.")
|
||||
if lax_internal._dtype(where) != bool:
|
||||
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
|
||||
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
|
||||
if primitive is None:
|
||||
reducer = self._reduce_via_scan
|
||||
else:
|
||||
reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
|
||||
return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
|
||||
reduce = self.__static_props['reduce'] or self._reduce_via_scan
|
||||
return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
|
||||
|
||||
def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
|
||||
def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
assert self.nin == 2 and self.nout == 1
|
||||
@ -202,9 +292,9 @@ class ufunc:
|
||||
|
||||
def body_fun(i, val):
|
||||
if where is None:
|
||||
return self._call(val, arr[i].astype(dtype))
|
||||
return self(val, arr[i].astype(dtype))
|
||||
else:
|
||||
return _where(where[i], self._call(val, arr[i].astype(dtype)), val)
|
||||
return _where(where[i], self(val, arr[i].astype(dtype)), val)
|
||||
|
||||
start_value: ArrayLike
|
||||
if initial is None:
|
||||
@ -221,22 +311,63 @@ class ufunc:
|
||||
result = result.reshape(final_shape)
|
||||
return result
|
||||
|
||||
@implements(np.ufunc.accumulate, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
|
||||
def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
|
||||
out: None = None) -> Array:
|
||||
"""Accumulate operation derived from binary ufunc.
|
||||
|
||||
JAX implementation of :func:`numpy.ufunc.accumulate`.
|
||||
|
||||
Args:
|
||||
a: N-dimensional array over which to accumulate.
|
||||
axis: integer axis over which accumulation will be performed (default = 0)
|
||||
dtype: optionally specify the type of the output array.
|
||||
out: Unused by JAX
|
||||
|
||||
Returns:
|
||||
An array containing the accumulated result.
|
||||
|
||||
Examples:
|
||||
Consider the following array:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
|
||||
:meth:`jax.numpy.add.accumulate` is equivalent to
|
||||
:func:`jax.numpy.cumsum` along the specified axis:
|
||||
>>> jnp.add.accumulate(x, axis=1)
|
||||
Array([[ 1, 3, 6],
|
||||
[ 4, 9, 15]], dtype=int32)
|
||||
>>> jnp.cumsum(x, axis=1)
|
||||
Array([[ 1, 3, 6],
|
||||
[ 4, 9, 15]], dtype=int32)
|
||||
|
||||
Similarly, :meth:`jax.numpy.multiply.accumulate` is equivalent to
|
||||
:func:`jax.numpy.cumprod` along the specified axis:
|
||||
|
||||
>>> jnp.multiply.accumulate(x, axis=1)
|
||||
Array([[ 1, 2, 6],
|
||||
[ 4, 20, 120]], dtype=int32)
|
||||
>>> jnp.cumprod(x, axis=1)
|
||||
Array([[ 1, 2, 6],
|
||||
[ 4, 20, 120]], dtype=int32)
|
||||
|
||||
For other binary ufuncs, the accumulation is an operation not available
|
||||
via standard APIs. For example, :meth:`jax.numpy.bitwise_or.accumulate`
|
||||
is essentially a bitwise cumulative ``any``:
|
||||
|
||||
>>> jnp.bitwise_or.accumulate(x, axis=1)
|
||||
Array([[1, 3, 3],
|
||||
[4, 5, 7]], dtype=int32)
|
||||
"""
|
||||
if self.nin != 2:
|
||||
raise ValueError("accumulate only supported for binary ufuncs")
|
||||
if self.nout != 1:
|
||||
raise ValueError("accumulate only supported for functions returning a single value")
|
||||
if out is not None:
|
||||
raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
|
||||
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
|
||||
if primitive is None:
|
||||
accumulator = self._accumulate_via_scan
|
||||
else:
|
||||
accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
|
||||
return accumulator(a, axis=axis, dtype=dtype)
|
||||
accumulate = self.__static_props['accumulate'] or self._accumulate_via_scan
|
||||
return accumulate(a, axis=axis, dtype=dtype)
|
||||
|
||||
def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0,
|
||||
dtype: DTypeLike | None = None) -> Array:
|
||||
@ -254,21 +385,54 @@ class ufunc:
|
||||
arr = _moveaxis(arr, axis, 0)
|
||||
def scan_fun(carry, _):
|
||||
i, x = carry
|
||||
y = _where(i == 0, arr[0].astype(dtype), self._call(x.astype(dtype), arr[i].astype(dtype)))
|
||||
y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype)))
|
||||
return (i + 1, y), y
|
||||
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
|
||||
return _moveaxis(result, 0, axis)
|
||||
|
||||
@implements(np.ufunc.at, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
|
||||
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
|
||||
inplace: bool = True) -> Array:
|
||||
"""Update elements of an array via the specified unary or binary ufunc.
|
||||
|
||||
JAX implementation of :func:`numpy.ufunc.at`.
|
||||
|
||||
Note:
|
||||
:meth:`numpy.ufunc.at` mutates arrays in-place. JAX arrays are immutable,
|
||||
so :meth:`jax.numpy.ufunc.at` cannot replicate these semantics. Instead, JAX
|
||||
will return the updated value, but requires explicitly passing ``inplace=False``
|
||||
as a reminder of this difference.
|
||||
|
||||
Args:
|
||||
a: N-dimensional array to update
|
||||
indices: index, slice, or tuple of indices and slices.
|
||||
b: array of values for binary ufunc updates.
|
||||
inplace: must be set to False to indicate that an updated copy will be returned.
|
||||
|
||||
Returns:
|
||||
an updated copy of the input array.
|
||||
|
||||
Examples:
|
||||
|
||||
Add numbers to specified indices:
|
||||
|
||||
>>> x = jnp.ones(10, dtype=int)
|
||||
>>> indices = jnp.array([2, 5, 7])
|
||||
>>> values = jnp.array([10, 20, 30])
|
||||
>>> jnp.add.at(x, indices, values, inplace=False)
|
||||
Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
|
||||
|
||||
This is roughly equivalent to JAX's :meth:`jax.numpy.ndarray.at` method
|
||||
called this way:
|
||||
|
||||
>>> x.at[indices].add(values)
|
||||
Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
|
||||
"""
|
||||
if inplace:
|
||||
raise NotImplementedError(_AT_INPLACE_WARNING)
|
||||
if b is None:
|
||||
return self._at_via_scan(a, indices)
|
||||
else:
|
||||
return self._at_via_scan(a, indices, b)
|
||||
|
||||
at = self.__static_props['at'] or self._at_via_scan
|
||||
return at(a, indices) if b is None else at(a, indices, b)
|
||||
|
||||
def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
|
||||
assert len(args) in {0, 1}
|
||||
@ -276,14 +440,14 @@ class ufunc:
|
||||
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
|
||||
a = lax_internal.asarray(a).astype(dtype)
|
||||
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
|
||||
indices = _eliminate_deprecated_list_indexing(indices)
|
||||
indices = jnp._eliminate_deprecated_list_indexing(indices)
|
||||
if not indices:
|
||||
return a
|
||||
|
||||
shapes = [np.shape(i) for i in indices if not isinstance(i, slice)]
|
||||
shape = shapes and jax.lax.broadcast_shapes(*shapes)
|
||||
if not shape:
|
||||
return a.at[indices].set(self._call(a.at[indices].get(), *args))
|
||||
return a.at[indices].set(self(a.at[indices].get(), *args))
|
||||
|
||||
if args:
|
||||
arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):]))
|
||||
@ -293,28 +457,65 @@ class ufunc:
|
||||
def scan_fun(carry, x):
|
||||
i, a = carry
|
||||
idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
|
||||
a = a.at[idx].set(self._call(a.at[idx].get(), *(arg[i] for arg in args)))
|
||||
a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
|
||||
return (i + 1, a), x
|
||||
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
|
||||
return carry[1]
|
||||
|
||||
@implements(np.ufunc.reduceat, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
|
||||
def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0,
|
||||
dtype: DTypeLike | None = None, out: None = None) -> Array:
|
||||
"""Reduce an array between specified indices via a binary ufunc.
|
||||
|
||||
JAX implementation of :meth:`numpy.ufunc.reduceat`
|
||||
|
||||
Args:
|
||||
a: N-dimensional array to reduce
|
||||
indices: a 1-dimensional array of increasing integer values which encodes
|
||||
segments of the array to be reduced.
|
||||
axis: integer specifying the axis along which to reduce: default=0.
|
||||
dtype: optionally specify the dtype of the output array.
|
||||
out: unused by JAX
|
||||
Returns:
|
||||
An array containing the reduced values.
|
||||
|
||||
Examples:
|
||||
The ``reduce`` method lets you efficiently compute reduction operations
|
||||
over array segments. For example:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
>>> indices = jnp.array([0, 2, 5])
|
||||
>>> jnp.add.reduce(x, indices)
|
||||
Array([ 3, 12, 21], dtype=int32)
|
||||
|
||||
This is more-or-less equivalent to the following:
|
||||
|
||||
>>> jnp.array([x[0:2].sum(), x[2:5].sum(), x[5:].sum()])
|
||||
Array([ 3, 12, 21], dtype=int32)
|
||||
|
||||
For some binary ufuncs, JAX provides similar APIs within :mod:`jax.ops`.
|
||||
For example, :meth:`jax.add.reduceat` is similar to :func:`jax.ops.segment_sum`,
|
||||
although in this case the segments are defined via an array of segment ids:
|
||||
|
||||
>>> segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 2])
|
||||
>>> jax.ops.segment_sum(x, segments)
|
||||
Array([ 3, 12, 21], dtype=int32)
|
||||
"""
|
||||
if self.nin != 2:
|
||||
raise ValueError("reduceat only supported for binary ufuncs")
|
||||
if self.nout != 1:
|
||||
raise ValueError("reduceat only supported for functions returning a single value")
|
||||
if out is not None:
|
||||
raise NotImplementedError(f"out argument of {self.__name__}.reduceat()")
|
||||
return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype)
|
||||
|
||||
reduceat = self.__static_props['reduceat'] or self._reduceat_via_scan
|
||||
return reduceat(a, indices, axis=axis, dtype=dtype)
|
||||
|
||||
def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0,
|
||||
dtype: DTypeLike | None = None) -> Array:
|
||||
check_arraylike(f"{self.__name__}.reduceat", a, indices)
|
||||
a = lax_internal.asarray(a)
|
||||
idx_tuple = _eliminate_deprecated_list_indexing(indices)
|
||||
idx_tuple = jnp._eliminate_deprecated_list_indexing(indices)
|
||||
assert len(idx_tuple) == 1
|
||||
indices = idx_tuple[0]
|
||||
if a.ndim == 0:
|
||||
@ -326,27 +527,62 @@ class ufunc:
|
||||
if axis is None or isinstance(axis, (tuple, list)):
|
||||
raise ValueError("reduceat requires a single integer axis.")
|
||||
axis = canonicalize_axis(axis, a.ndim)
|
||||
out = take(a, indices, axis=axis)
|
||||
ind = jax.lax.expand_dims(append(indices, a.shape[axis]),
|
||||
out = jnp.take(a, indices, axis=axis)
|
||||
ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]),
|
||||
list(np.delete(np.arange(out.ndim), axis)))
|
||||
ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis)
|
||||
ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
|
||||
def loop_body(i, out):
|
||||
return _where((i > ind_start) & (i < ind_end),
|
||||
self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
|
||||
self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
|
||||
out)
|
||||
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
|
||||
|
||||
@implements(np.ufunc.outer, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnums=[0])
|
||||
def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array:
|
||||
def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array:
|
||||
"""Apply the function to all pairs of values in ``A`` and ``B``.
|
||||
|
||||
JAX implementation of :meth:`numpy.ufunc.outer`.
|
||||
|
||||
Args:
|
||||
A: N-dimensional array
|
||||
B: N-dimensional array
|
||||
|
||||
Returns:
|
||||
An array of shape `tuple(*A.shape, *B.shape)`
|
||||
|
||||
Examples:
|
||||
A times-table for integers 1...10 created via
|
||||
:meth:`jax.numpy.multiply.outer`:
|
||||
|
||||
>>> x = jnp.arange(1, 11)
|
||||
>>> print(jnp.multiply.outer(x, x))
|
||||
[[ 1 2 3 4 5 6 7 8 9 10]
|
||||
[ 2 4 6 8 10 12 14 16 18 20]
|
||||
[ 3 6 9 12 15 18 21 24 27 30]
|
||||
[ 4 8 12 16 20 24 28 32 36 40]
|
||||
[ 5 10 15 20 25 30 35 40 45 50]
|
||||
[ 6 12 18 24 30 36 42 48 54 60]
|
||||
[ 7 14 21 28 35 42 49 56 63 70]
|
||||
[ 8 16 24 32 40 48 56 64 72 80]
|
||||
[ 9 18 27 36 45 54 63 72 81 90]
|
||||
[ 10 20 30 40 50 60 70 80 90 100]]
|
||||
|
||||
For input arrays with ``N`` and ``M`` dimensions respectively, the output
|
||||
will have dimesion ``N + M``:
|
||||
|
||||
>>> x = jnp.ones((1, 3, 5))
|
||||
>>> y = jnp.ones((2, 4))
|
||||
>>> jnp.add.outer(x, y).shape
|
||||
(1, 3, 5, 2, 4)
|
||||
"""
|
||||
if self.nin != 2:
|
||||
raise ValueError("outer only supported for binary ufuncs")
|
||||
if self.nout != 1:
|
||||
raise ValueError("outer only supported for functions returning a single value")
|
||||
check_arraylike(f"{self.__name__}.outer", A, B)
|
||||
_ravel = lambda A: jax.lax.reshape(A, (np.size(A),))
|
||||
result = jax.vmap(jax.vmap(partial(self._call, **kwargs), (None, 0)), (0, None))(_ravel(A), _ravel(B))
|
||||
result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B))
|
||||
return result.reshape(*np.shape(A), *np.shape(B))
|
||||
|
||||
|
||||
@ -363,4 +599,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
|
||||
Returns:
|
||||
wrapped : jax.numpy.ufunc wrapper of func.
|
||||
"""
|
||||
return ufunc(func, nin, nout, identity=identity, update_doc=True)
|
||||
return ufunc(func, nin, nout, identity=identity)
|
||||
|
@ -30,11 +30,13 @@ from jax._src.api import jit
|
||||
from jax._src.custom_derivatives import custom_jvp
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import other as lax_other
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, promote_args, promote_args_inexact,
|
||||
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
|
||||
promote_shapes, _where, implements, check_no_float0s)
|
||||
from jax._src.numpy.ufunc_api import ufunc
|
||||
from jax._src.numpy import reductions
|
||||
|
||||
_lax_const = lax._const
|
||||
|
||||
@ -298,31 +300,81 @@ def sqrt(x: ArrayLike, /) -> Array:
|
||||
def cbrt(x: ArrayLike, /) -> Array:
|
||||
return lax.cbrt(*promote_args_inexact('cbrt', x))
|
||||
|
||||
@implements(np.add, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def add(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Add two arrays element-wise.
|
||||
|
||||
JAX implementation of :obj:`numpy.add`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: arrays to add. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise addition.
|
||||
"""
|
||||
x, y = promote_args("add", x, y)
|
||||
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
|
||||
|
||||
@implements(np.multiply, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Multiply two arrays element-wise.
|
||||
|
||||
JAX implementation of :obj:`numpy.multiply`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: arrays to multiply. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise multiplication.
|
||||
"""
|
||||
x, y = promote_args("multiply", x, y)
|
||||
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
|
||||
|
||||
@implements(np.bitwise_and, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Compute the bitwise AND operation elementwise.
|
||||
|
||||
JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise bitwise AND.
|
||||
"""
|
||||
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
|
||||
|
||||
@implements(np.bitwise_or, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Compute the bitwise OR operation elementwise.
|
||||
|
||||
JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise bitwise OR.
|
||||
"""
|
||||
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
|
||||
|
||||
@implements(np.bitwise_xor, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Compute the bitwise XOR operation elementwise.
|
||||
|
||||
JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise bitwise XOR.
|
||||
"""
|
||||
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))
|
||||
|
||||
@implements(np.left_shift, module='numpy')
|
||||
@ -376,19 +428,49 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
|
||||
|
||||
# Logical ops
|
||||
@implements(np.logical_and, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Compute the logical AND operation elementwise.
|
||||
|
||||
JAX implementation of :obj:`numpy.logical_and`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: input arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise logical AND.
|
||||
"""
|
||||
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))
|
||||
|
||||
@implements(np.logical_or, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Compute the logical OR operation elementwise.
|
||||
|
||||
JAX implementation of :obj:`numpy.logical_or`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: input arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise logical OR.
|
||||
"""
|
||||
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))
|
||||
|
||||
@implements(np.logical_xor, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Compute the logical XOR operation elementwise.
|
||||
|
||||
JAX implementation of :obj:`numpy.logical_xor`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
|
||||
Args:
|
||||
x, y: input arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise logical XOR.
|
||||
"""
|
||||
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))
|
||||
|
||||
@implements(np.logical_not, module='numpy')
|
||||
@ -1281,3 +1363,38 @@ def _sinc_maclaurin(k, x):
|
||||
def _sinc_maclaurin_jvp(k, primals, tangents):
|
||||
(x,), (t,) = primals, tangents
|
||||
return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t
|
||||
|
||||
|
||||
def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None):
|
||||
if initial is not None:
|
||||
raise ValueError("initial argument not supported in jnp.logical_and.reduce()")
|
||||
result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where)
|
||||
return result if dtype is None else result.astype(dtype)
|
||||
|
||||
|
||||
def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None):
|
||||
if initial is not None:
|
||||
raise ValueError("initial argument not supported in jnp.logical_or.reduce()")
|
||||
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
|
||||
return result if dtype is None else result.astype(dtype)
|
||||
|
||||
|
||||
# Generate ufunc interfaces for several common binary functions.
|
||||
# We start with binary ufuncs that have well-defined identities.'
|
||||
# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience?
|
||||
# TODO(jakevdp): optimize some implementations.
|
||||
# - define add.at/multiply.at in terms of scatter_add/scatter_mul
|
||||
# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod
|
||||
# - define all monoidal reductions in terms of lax.reduce
|
||||
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum)
|
||||
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod)
|
||||
bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and)
|
||||
bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or)
|
||||
bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor)
|
||||
logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce)
|
||||
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
|
||||
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)
|
||||
|
@ -965,8 +965,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
|
||||
else:
|
||||
graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
|
||||
if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def:
|
||||
self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def)
|
||||
if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def:
|
||||
self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def)
|
||||
|
||||
def test_bfloat16_constant(self):
|
||||
# Re: https://github.com/google/jax/issues/3942
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Literal, NamedTuple, TypeVar, Union, overload
|
||||
from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload
|
||||
|
||||
from jax._src import core as _core
|
||||
from jax._src import dtypes as _dtypes
|
||||
@ -28,6 +28,34 @@ _Device = Device
|
||||
|
||||
ComplexWarning: type
|
||||
|
||||
class BinaryUfunc(Protocol):
|
||||
@property
|
||||
def nin(self) -> int: ...
|
||||
@property
|
||||
def nout(self) -> int: ...
|
||||
@property
|
||||
def nargs(self) -> int: ...
|
||||
@property
|
||||
def identity(self) -> builtins.bool | int | float: ...
|
||||
def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def reduce(self, arr: ArrayLike, /, *,
|
||||
axis: int | None = 0,
|
||||
dtype: DTypeLike | None = None,
|
||||
keepdims: builtins.bool = False,
|
||||
initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array: ...
|
||||
def accumulate(self, a: ArrayLike, /, *,
|
||||
axis: int = 0,
|
||||
dtype: DTypeLike | None = None,
|
||||
out: None = None) -> Array: ...
|
||||
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
|
||||
inplace: builtins.bool = True) -> Array: ...
|
||||
def reduceat(self, a: ArrayLike, indices: Any, *,
|
||||
axis: int = 0,
|
||||
dtype: DTypeLike | None = None,
|
||||
out: None = None) -> Array: ...
|
||||
def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ...
|
||||
|
||||
__array_api_version__: str
|
||||
def __array_namespace_info__() -> ArrayNamespaceInfo: ...
|
||||
|
||||
@ -36,7 +64,7 @@ def abs(x: ArrayLike, /) -> Array: ...
|
||||
def absolute(x: ArrayLike, /) -> Array: ...
|
||||
def acos(x: ArrayLike, /) -> Array: ...
|
||||
def acosh(x: ArrayLike, /) -> Array: ...
|
||||
def add(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
add: BinaryUfunc
|
||||
def amax(a: ArrayLike, axis: _Axis = ..., out: None = ...,
|
||||
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
|
||||
where: ArrayLike | None = ...) -> Array: ...
|
||||
@ -162,14 +190,14 @@ def bartlett(M: int) -> Array: ...
|
||||
bfloat16: Any
|
||||
def bincount(x: ArrayLike, weights: ArrayLike | None = ...,
|
||||
minlength: int = ..., *, length: int | None = ...) -> Array: ...
|
||||
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
bitwise_and: BinaryUfunc
|
||||
def bitwise_count(x: ArrayLike, /) -> Array: ...
|
||||
def bitwise_invert(x: ArrayLike, /) -> Array: ...
|
||||
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def bitwise_not(x: ArrayLike, /) -> Array: ...
|
||||
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
bitwise_or: BinaryUfunc
|
||||
def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
bitwise_xor: BinaryUfunc
|
||||
def blackman(M: int) -> Array: ...
|
||||
def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ...
|
||||
bool: Any
|
||||
@ -251,7 +279,7 @@ def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
||||
out: None = ...) -> Array: ...
|
||||
def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ...,
|
||||
dtype: DTypeLike | None = ...,
|
||||
include_initial: bool = ...) -> Array: ...
|
||||
include_initial: builtins.bool = ...) -> Array: ...
|
||||
|
||||
def deg2rad(x: ArrayLike, /) -> Array: ...
|
||||
degrees = rad2deg
|
||||
@ -557,10 +585,10 @@ def log1p(x: ArrayLike, /) -> Array: ...
|
||||
def log2(x: ArrayLike, /) -> Array: ...
|
||||
def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
logical_and: BinaryUfunc
|
||||
def logical_not(x: ArrayLike, /) -> Array: ...
|
||||
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
logical_or: BinaryUfunc
|
||||
logical_xor: BinaryUfunc
|
||||
def logspace(start: ArrayLike, stop: ArrayLike, num: int = ...,
|
||||
endpoint: builtins.bool = ..., base: ArrayLike = ...,
|
||||
dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ...
|
||||
@ -588,7 +616,7 @@ def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ...
|
||||
def moveaxis(a: ArrayLike, source: int | Sequence[int],
|
||||
destination: int | Sequence[int]) -> Array: ...
|
||||
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
multiply: BinaryUfunc
|
||||
nan: float
|
||||
def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ...,
|
||||
posinf: ArrayLike | None = ...,
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
"""Tests for jax.numpy.ufunc and its methods."""
|
||||
|
||||
import itertools
|
||||
from functools import partial
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -22,7 +23,6 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.numpy.ufunc_api import get_if_single_primitive
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
@ -54,18 +54,21 @@ SCALAR_FUNCS = [
|
||||
{'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
|
||||
]
|
||||
|
||||
FASTPATH_FUNCS = [
|
||||
{'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0,
|
||||
'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p},
|
||||
{'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1,
|
||||
'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p},
|
||||
def _jnp_ufunc_props(name):
|
||||
jnp_func = getattr(jnp, name)
|
||||
assert isinstance(jnp_func, jnp.ufunc)
|
||||
np_func = getattr(np, name)
|
||||
dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types]
|
||||
return [dict(name=name, dtype=dtype) for dtype in dtypes]
|
||||
|
||||
|
||||
JAX_NUMPY_UFUNCS = [
|
||||
name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc)
|
||||
]
|
||||
|
||||
NON_FASTPATH_FUNCS = [
|
||||
{'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0},
|
||||
{'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1},
|
||||
{'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1},
|
||||
]
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
|
||||
_jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS
|
||||
))
|
||||
|
||||
broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
|
||||
nonscalar_shapes = [(3,), (4,), (4, 3)]
|
||||
@ -80,23 +83,40 @@ def cast_outputs(fun):
|
||||
class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(SCALAR_FUNCS)
|
||||
def test_ufunc_properties(self, func, nin, nout, identity):
|
||||
def test_frompyfunc_properties(self, func, nin, nout, identity):
|
||||
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
|
||||
self.assertEqual(jnp_fun.identity, identity)
|
||||
self.assertEqual(jnp_fun.nin, nin)
|
||||
self.assertEqual(jnp_fun.nout, nout)
|
||||
self.assertEqual(jnp_fun.nargs, nin)
|
||||
|
||||
@jtu.sample_product(name=JAX_NUMPY_UFUNCS)
|
||||
def test_ufunc_properties(self, name):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
self.assertEqual(jnp_fun.identity, np_fun.identity)
|
||||
self.assertEqual(jnp_fun.nin, np_fun.nin)
|
||||
self.assertEqual(jnp_fun.nout, np_fun.nout)
|
||||
self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 because NumPy accepts `out`
|
||||
|
||||
@jtu.sample_product(SCALAR_FUNCS)
|
||||
def test_ufunc_properties_readonly(self, func, nin, nout, identity):
|
||||
def test_frompyfunc_properties_readonly(self, func, nin, nout, identity):
|
||||
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
|
||||
for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']:
|
||||
for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
|
||||
getattr(jnp_fun, attr) # no error on attribute access.
|
||||
with self.assertRaises(AttributeError):
|
||||
setattr(jnp_fun, attr, None) # error when trying to mutate.
|
||||
|
||||
@jtu.sample_product(name=JAX_NUMPY_UFUNCS)
|
||||
def test_ufunc_properties_readonly(self, name):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
|
||||
getattr(jnp_fun, attr) # no error on attribute access.
|
||||
with self.assertRaises(AttributeError):
|
||||
setattr(jnp_fun, attr, None) # error when trying to mutate.
|
||||
|
||||
@jtu.sample_product(SCALAR_FUNCS)
|
||||
def test_ufunc_hash(self, func, nin, nout, identity):
|
||||
def test_frompyfunc_hash(self, func, nin, nout, identity):
|
||||
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
|
||||
jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
|
||||
self.assertEqual(jnp_fun, jnp_fun_2)
|
||||
@ -113,7 +133,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
|
||||
def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
|
||||
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
|
||||
np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity))
|
||||
|
||||
@ -123,13 +143,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES,
|
||||
lhs_shape=broadcast_compatible_shapes,
|
||||
rhs_shape=broadcast_compatible_shapes,
|
||||
)
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
lhs_shape=broadcast_compatible_shapes,
|
||||
rhs_shape=broadcast_compatible_shapes,
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
|
||||
def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer
|
||||
@ -141,6 +176,23 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES,
|
||||
lhs_shape=broadcast_compatible_shapes,
|
||||
rhs_shape=broadcast_compatible_shapes,
|
||||
)
|
||||
def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
|
||||
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker)
|
||||
self._CompileAndCheck(jnp_fun.outer, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
@ -148,7 +200,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
for axis in [None, *range(-len(shape), len(shape))]],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_reduce(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
|
||||
@ -160,6 +212,26 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in nonscalar_shapes
|
||||
for axis in [None, *range(-len(shape), len(shape))]],
|
||||
)
|
||||
def test_ufunc_reduce(self, name, shape, axis, dtype):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
|
||||
jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
|
||||
np_fun_reduce = partial(np_fun.reduce, axis=axis)
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
|
||||
self._CompileAndCheck(jnp_fun_reduce, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
@ -167,7 +239,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
for axis in [None, *range(-len(shape), len(shape))]],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
|
||||
@ -194,42 +266,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
FASTPATH_FUNCS,
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in nonscalar_shapes
|
||||
for axis in range(-len(shape), len(shape))],
|
||||
dtype=jtu.dtypes.floating,
|
||||
for axis in [None, *range(-len(shape), len(shape))]],
|
||||
)
|
||||
def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
|
||||
del accumulator # unused
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
def test_ufunc_reduce_where(self, name, shape, axis, dtype):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
|
||||
if jnp_fun.identity is None:
|
||||
self.skipTest("reduce with where requires identity")
|
||||
|
||||
jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where)
|
||||
np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where)
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args = (rng(shape, dtype),)
|
||||
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
|
||||
self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer)
|
||||
|
||||
@jtu.sample_product(
|
||||
NON_FASTPATH_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in nonscalar_shapes
|
||||
for axis in range(-len(shape), len(shape))],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args = (rng(shape, dtype),)
|
||||
|
||||
_ = func(0, 0) # function should not error.
|
||||
|
||||
reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
|
||||
self.assertIsNone(get_if_single_primitive(reduce_fun, *args))
|
||||
|
||||
accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
|
||||
self.assertIsNone(get_if_single_primitive(accum_fun, *args))
|
||||
rng_where = jtu.rand_bool(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
|
||||
self._CompileAndCheck(jnp_fun_reduce, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
@ -238,7 +296,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
for axis in range(-len(shape), len(shape))],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
|
||||
@ -251,20 +309,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
FASTPATH_FUNCS,
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in nonscalar_shapes
|
||||
for axis in range(-len(shape), len(shape))],
|
||||
dtype=jtu.dtypes.floating,
|
||||
for axis in range(-len(shape), len(shape))]
|
||||
)
|
||||
def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
|
||||
del reducer # unused
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
def test_ufunc_accumulate(self, name, shape, axis, dtype):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
|
||||
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args = (rng(shape, dtype),)
|
||||
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
|
||||
self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis)
|
||||
def np_fun_accumulate(x):
|
||||
# numpy accumulate has different dtype casting behavior.
|
||||
result = np_fun.accumulate(x, axis=axis)
|
||||
return result if x.dtype == bool else result.astype(x.dtype)
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker)
|
||||
self._CompileAndCheck(jnp_fun_accumulate, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
@ -272,7 +338,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
idx_shape=[(), (2,)],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
|
||||
def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False)
|
||||
@ -288,7 +354,31 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def test_at_broadcasting(self):
|
||||
@jtu.sample_product(
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES,
|
||||
shape=nonscalar_shapes,
|
||||
idx_shape=[(), (2,)],
|
||||
)
|
||||
def test_ufunc_at(self, name, shape, idx_shape, dtype):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
|
||||
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
|
||||
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]
|
||||
|
||||
jnp_fun_at = partial(jnp_fun.at, inplace=False)
|
||||
def np_fun_at(x, idx, y):
|
||||
x_copy = x.copy()
|
||||
np_fun.at(x_copy, idx, y)
|
||||
return x_copy
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker)
|
||||
self._CompileAndCheck(jnp_fun_at, args_maker)
|
||||
|
||||
def test_frompyfunc_at_broadcasting(self):
|
||||
# Regression test for https://github.com/google/jax/issues/18004
|
||||
args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
|
||||
np.arange(9.0).reshape(3, 3)]
|
||||
@ -309,7 +399,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
idx_shape=[(0,), (3,), (5,)],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
|
||||
def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
|
||||
if (nin, nout) != (2, 1):
|
||||
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
|
||||
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis)
|
||||
@ -322,6 +412,33 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
JAX_NUMPY_UFUNCS_WITH_DTYPES,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in nonscalar_shapes
|
||||
for axis in [*range(-len(shape), len(shape))]],
|
||||
idx_shape=[(0,), (3,), (5,)],
|
||||
)
|
||||
def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
|
||||
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
|
||||
if name in ['add', 'multiply'] and dtype == bool:
|
||||
# TODO(jakevdp): figure out how to fix thest cases.
|
||||
self.skipTest(f"known failure for {name}.reduceat with {dtype=}")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis])
|
||||
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
|
||||
|
||||
def np_fun_reduceat(x, i):
|
||||
# Numpy has different casting behavior.
|
||||
return np_fun.reduceat(x, i).astype(x.dtype)
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker)
|
||||
self._CompileAndCheck(jnp_fun.reduceat, args_maker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user