First pass at ufunc interfaces for several jax.numpy functions

This commit is contained in:
Jake VanderPlas 2024-08-30 11:53:02 -07:00
parent db4be03f02
commit a3d6cf007e
7 changed files with 692 additions and 188 deletions

View File

@ -27,6 +27,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields`` * ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
and ``meta_fields`` includes all dataclass fields with ``init=True`` and ``meta_fields`` includes all dataclass fields with ``init=True``
and only them, if ``nodetype`` is a dataclass. and only them, if ``nodetype`` is a dataclass.
* Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc`
interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`,
{obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`,
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
{obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`.
In future releases we plan to expand these to other ufuncs.
* Breaking changes * Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
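For illustration, a brief hedged sketch of the ufunc methods this changelog entry describes (outputs assume JAX's default int32/bool semantics):

import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])
jnp.add(x, 1)                  # element-wise call: [2, 3, 4, 5]
jnp.add.reduce(x)              # same as jnp.sum(x): 10
jnp.multiply.accumulate(x)     # same as jnp.cumprod(x): [1, 2, 6, 24]
jnp.logical_and.reduce(x > 2)  # same as jnp.all(x > 2): False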

View File

@ -30,7 +30,6 @@ from jax._src import api
from jax._src import core from jax._src import core
from jax._src import deprecations from jax._src import deprecations
from jax._src import dtypes from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import ( from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type, _broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
@ -2039,9 +2038,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
a_shape = a.shape a_shape = a.shape
if squash_nans: if squash_nans:
a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis) a = lax.sort(a, dimension=axis)
counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
shape_after_reduction = counts.shape shape_after_reduction = counts.shape
q = lax.expand_dims( q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
@ -2067,7 +2066,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
index[axis] = high index[axis] = high
high_value = a[tuple(index)] high_value = a[tuple(index)]
else: else:
a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a) a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
a = lax.sort(a, dimension=axis) a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
q = lax.mul(q, n - 1) q = lax.mul(q, n - 1)
@ -2223,7 +2222,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
Array([1.5, 3. , 4.5], dtype=float32) Array([1.5, 3. , 4.5], dtype=float32)
""" """
check_arraylike("nanpercentile", a, q) check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, 100.0) q, = promote_dtypes_inexact(q)
q = q / 100
if not isinstance(interpolation, DeprecatedArg): if not isinstance(interpolation, DeprecatedArg):
deprecations.warn( deprecations.warn(
"jax-numpy-quantile-interpolation", "jax-numpy-quantile-interpolation",

View File

@ -25,13 +25,11 @@ from typing import Any
import jax import jax
from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions import jax._src.numpy.lax_numpy as jnp
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
from jax._src.numpy.reductions import _moveaxis from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where from jax._src.numpy.util import check_arraylike, _broadcast_to, _where
from jax._src.numpy.vectorize import vectorize from jax._src.numpy.vectorize import vectorize
from jax._src.util import canonicalize_axis, set_module from jax._src.util import canonicalize_axis, set_module
from jax._src import pjit
import numpy as np import numpy as np
@ -42,81 +40,126 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
""" """
def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None:
"""
If fun(*args) lowers to a single primitive with inputs and outputs matching
function inputs and outputs, return that primitive. Otherwise return None.
"""
try:
jaxpr = jax.make_jaxpr(fun)(*args)
except:
return None
while len(jaxpr.eqns) == 1:
eqn = jaxpr.eqns[0]
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
return None
elif (eqn.primitive == pjit.pjit_p and
all(pjit.is_unspecified(sharding) for sharding in
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
jaxpr = jaxpr.eqns[0].params['jaxpr']
else:
return jaxpr.eqns[0].primitive
return None
_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = {
lax_internal.add_p: reductions.sum,
lax_internal.mul_p: reductions.prod,
}
_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = {
lax_internal.add_p: reductions.cumsum,
lax_internal.mul_p: reductions.cumprod,
}
@set_module('jax.numpy') @set_module('jax.numpy')
class ufunc: class ufunc:
"""Functions that operate element-by-element on whole arrays. """Universal functions which operation element-by-element on arrays.
This is a class for LAX-backed implementations of numpy ufuncs. JAX implementation of :class:`numpy.ufunc`.
This is a class for JAX-backed implementations of NumPy's ufunc APIs.
Most users will never need to instantiate :class:`ufunc`, but rather
will use the pre-defined ufuncs in :mod:`jax.numpy`.
For constructing your own ufuncs, see :func:`jax.numpy.frompyfunc`.
Examples:
Universal functions are functions that apply element-wise to broadcasted
arrays, but they also come with a number of extra attributes and methods.
As an example, consider the function :obj:`jax.numpy.add`. The object
acts as a function that applies addition to broadcasted arrays in an
element-wise manner:
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)
Each :class:`ufunc` object includes a number of attributes that describe
its behavior:
>>> jnp.add.nin # number of inputs
2
>>> jnp.add.nout # number of outputs
1
>>> jnp.add.identity # identity value, or None if no identity exists
0
Binary ufuncs like :obj:`jax.numpy.add` include a number of methods to
apply the function to arrays in different ways.
The :meth:`~ufunc.outer` method applies the function to the
pairwise outer product of the input array values:
>>> jnp.add.outer(x, x)
Array([[ 2, 3, 4, 5, 6],
[ 3, 4, 5, 6, 7],
[ 4, 5, 6, 7, 8],
[ 5, 6, 7, 8, 9],
[ 6, 7, 8, 9, 10]], dtype=int32)
The :meth:`ufunc.reduce` method performs a reduction over the array.
For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``:
>>> jnp.add.reduce(x)
Array(15, dtype=int32)
The :meth:`ufunc.accumulate` method performs a cumulative reduction
over the array. For example, :meth:`jnp.add.accumulate` is equivalent
to :func:`jax.numpy.cumulative_sum`:
>>> jnp.add.accumulate(x)
Array([ 1, 3, 6, 10, 15], dtype=int32)
The :meth:`ufunc.at` method applies the function at particular indices in the
array; for ``jnp.add`` the computation is similar to :func:`jax.lax.scatter_add`:
>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101, 2, 3, 4, 5], dtype=int32)
And the :meth:`ufunc.reduceat` method performs a number of ``reduce``
operations between specified indices of an array; for ``jnp.add`` the
operation is similar to :func:`jax.ops.segment_sum`:
>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)
In this case, the first element is ``x[0:2].sum()``, and the second element
is ``x[2:].sum()``.
""" """
def __init__(self, func: Callable[..., Any], /, def __init__(self, func: Callable[..., Any], /,
nin: int, nout: int, *, nin: int, nout: int, *,
name: str | None = None, name: str | None = None,
nargs: int | None = None, nargs: int | None = None,
identity: Any = None, update_doc=False): identity: Any = None,
call: Callable[..., Any] | None = None,
reduce: Callable[..., Any] | None = None,
accumulate: Callable[..., Any] | None = None,
at: Callable[..., Any] | None = None,
reduceat: Callable[..., Any] | None = None,
):
self.__doc__ = func.__doc__
self.__name__ = name or func.__name__
# We want ufunc instances to work properly when marked as static, # We want ufunc instances to work properly when marked as static,
# and for this reason it's important that their properties not be # and for this reason it's important that their properties not be
# mutated. We prevent this by storing them in a dunder attribute, # mutated. We prevent this by storing them in a dunder attribute,
# and accessing them via read-only properties. # and accessing them via read-only properties.
if update_doc:
self.__doc__ = func.__doc__
self.__name__ = name or func.__name__
self.__static_props = { self.__static_props = {
'func': func, 'func': func,
'call': vectorize(func),
'nin': operator.index(nin), 'nin': operator.index(nin),
'nout': operator.index(nout), 'nout': operator.index(nout),
'nargs': operator.index(nargs or nin), 'nargs': operator.index(nargs or nin),
'identity': identity 'identity': identity,
'call': call,
'reduce': reduce,
'accumulate': accumulate,
'at': at,
'reduceat': reduceat,
} }
_func = property(lambda self: self.__static_props['func']) _func = property(lambda self: self.__static_props['func'])
_call = property(lambda self: self.__static_props['call'])
nin = property(lambda self: self.__static_props['nin']) nin = property(lambda self: self.__static_props['nin'])
nout = property(lambda self: self.__static_props['nout']) nout = property(lambda self: self.__static_props['nout'])
nargs = property(lambda self: self.__static_props['nargs']) nargs = property(lambda self: self.__static_props['nargs'])
identity = property(lambda self: self.__static_props['identity']) identity = property(lambda self: self.__static_props['identity'])
def __hash__(self) -> int: def __hash__(self) -> int:
# Do not include _call, because it is computed from _func. # In both __hash__ and __eq__, we do not consider call, reduce, etc.
# because they are considered implementation details rather than
# necessary parts of object identity.
return hash((self._func, self.__name__, self.identity, return hash((self._func, self.__name__, self.identity,
self.nin, self.nout, self.nargs)) self.nin, self.nout, self.nargs))
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
# Do not include _call, because it is computed from _func.
return isinstance(other, ufunc) and ( return isinstance(other, ufunc) and (
(self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) == (self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) ==
(other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs)) (other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs))
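These hash/eq semantics are what allow a ufunc instance to be passed as a static argument under ``jit``; a hedged sketch (the helper name is hypothetical):

from functools import partial
import jax
import jax.numpy as jnp

# 'op' is marked static, so jit keys its trace cache on the ufunc's hash,
# which is derived from the immutable static properties above.
@partial(jax.jit, static_argnames="op")
def reduce_with(op, x):
    return op.reduce(x, axis=0)

reduce_with(jnp.add, jnp.arange(4))          # Array(6, dtype=int32)
reduce_with(jnp.multiply, jnp.arange(1, 5))  # Array(24, dtype=int32)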
@ -124,20 +167,71 @@ class ufunc:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<jnp.ufunc '{self.__name__}'>" return f"<jnp.ufunc '{self.__name__}'>"
def __call__(self, *args: ArrayLike, def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any:
out: None = None, where: None = None, check_arraylike(self.__name__, *args)
**kwargs: Any) -> Any:
if out is not None: if out is not None:
raise NotImplementedError(f"out argument of {self}") raise NotImplementedError(f"out argument of {self}")
if where is not None: if where is not None:
raise NotImplementedError(f"where argument of {self}") raise NotImplementedError(f"where argument of {self}")
return self._call(*args, **kwargs) call = self.__static_props['call'] or self._call_vectorized
return call(*args)
@partial(jax.jit, static_argnames=['self'])
def _call_vectorized(self, *args):
return vectorize(self._func)(*args)
@implements(np.ufunc.reduce, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, def reduce(self, a: ArrayLike, axis: int = 0,
dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array: where: ArrayLike | None = None) -> Array:
"""Reduction operation derived from a binary function.
JAX implementation of :meth:`numpy.ufunc.reduce`.
Args:
a: Input array.
axis: integer specifying the axis over which to reduce. default=0
dtype: optionally specify the type of the output array.
out: Unused by JAX
keepdims: If True, reduced axes are left in the result with size 1.
If False (default) then reduced axes are squeezed out.
initial: int or array, default=None. Initial value for the reduction.
where: boolean mask, default=None. The elements to include in the reduction.
Must be broadcast-compatible with the input array.
Returns:
array containing the result of the reduction operation.
Examples:
Consider the following array:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
:meth:`jax.numpy.add.reduce` is equivalent to :func:`jax.numpy.sum`
along ``axis=0``:
>>> jnp.add.reduce(x)
Array([5, 7, 9], dtype=int32)
>>> x.sum(0)
Array([5, 7, 9], dtype=int32)
Similarly, :meth:`jax.numpy.logical_and.reduce` is equivalent to
:func:`jax.numpy.all`:
>>> jnp.logical_and.reduce(x > 2)
Array([False, False, True], dtype=bool)
>>> jnp.all(x > 2, axis=0)
Array([False, False, True], dtype=bool)
Some reductions do not correspond to any built-in aggregation function;
for example, here is the reduction of :func:`jax.numpy.bitwise_or` along
the last axis of ``x``:
>>> jnp.bitwise_or.reduce(x, axis=1)
Array([3, 7], dtype=int32)
"""
check_arraylike(f"{self.__name__}.reduce", a) check_arraylike(f"{self.__name__}.reduce", a)
if self.nin != 2: if self.nin != 2:
raise ValueError("reduce only supported for binary ufuncs") raise ValueError("reduce only supported for binary ufuncs")
@ -154,14 +248,10 @@ class ufunc:
"so to use a where mask one has to specify 'initial'.") "so to use a where mask one has to specify 'initial'.")
if lax_internal._dtype(where) != bool: if lax_internal._dtype(where) != bool:
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) reduce = self.__static_props['reduce'] or self._reduce_via_scan
if primitive is None: return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
reducer = self._reduce_via_scan
else:
reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None,
keepdims: bool = False, initial: ArrayLike | None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array: where: ArrayLike | None = None) -> Array:
assert self.nin == 2 and self.nout == 1 assert self.nin == 2 and self.nout == 1
@ -202,9 +292,9 @@ class ufunc:
def body_fun(i, val): def body_fun(i, val):
if where is None: if where is None:
return self._call(val, arr[i].astype(dtype)) return self(val, arr[i].astype(dtype))
else: else:
return _where(where[i], self._call(val, arr[i].astype(dtype)), val) return _where(where[i], self(val, arr[i].astype(dtype)), val)
start_value: ArrayLike start_value: ArrayLike
if initial is None: if initial is None:
@ -221,22 +311,63 @@ class ufunc:
result = result.reshape(final_shape) result = result.reshape(final_shape)
return result return result
@implements(np.ufunc.accumulate, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None) -> Array: out: None = None) -> Array:
"""Accumulate operation derived from binary ufunc.
JAX implementation of :meth:`numpy.ufunc.accumulate`.
Args:
a: N-dimensional array over which to accumulate.
axis: integer axis over which accumulation will be performed (default = 0)
dtype: optionally specify the type of the output array.
out: Unused by JAX
Returns:
An array containing the accumulated result.
Examples:
Consider the following array:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
:meth:`jax.numpy.add.accumulate` is equivalent to
:func:`jax.numpy.cumsum` along the specified axis:
>>> jnp.add.accumulate(x, axis=1)
Array([[ 1, 3, 6],
[ 4, 9, 15]], dtype=int32)
>>> jnp.cumsum(x, axis=1)
Array([[ 1, 3, 6],
[ 4, 9, 15]], dtype=int32)
Similarly, :meth:`jax.numpy.multiply.accumulate` is equivalent to
:func:`jax.numpy.cumprod` along the specified axis:
>>> jnp.multiply.accumulate(x, axis=1)
Array([[ 1, 2, 6],
[ 4, 20, 120]], dtype=int32)
>>> jnp.cumprod(x, axis=1)
Array([[ 1, 2, 6],
[ 4, 20, 120]], dtype=int32)
For other binary ufuncs, the accumulation is an operation not available
via standard APIs. For example, :meth:`jax.numpy.bitwise_or.accumulate`
computes a cumulative bitwise OR along the specified axis:
>>> jnp.bitwise_or.accumulate(x, axis=1)
Array([[1, 3, 3],
[4, 5, 7]], dtype=int32)
"""
if self.nin != 2: if self.nin != 2:
raise ValueError("accumulate only supported for binary ufuncs") raise ValueError("accumulate only supported for binary ufuncs")
if self.nout != 1: if self.nout != 1:
raise ValueError("accumulate only supported for functions returning a single value") raise ValueError("accumulate only supported for functions returning a single value")
if out is not None: if out is not None:
raise NotImplementedError(f"out argument of {self.__name__}.accumulate()") raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) accumulate = self.__static_props['accumulate'] or self._accumulate_via_scan
if primitive is None: return accumulate(a, axis=axis, dtype=dtype)
accumulator = self._accumulate_via_scan
else:
accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
return accumulator(a, axis=axis, dtype=dtype)
def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0,
dtype: DTypeLike | None = None) -> Array: dtype: DTypeLike | None = None) -> Array:
@ -254,21 +385,54 @@ class ufunc:
arr = _moveaxis(arr, axis, 0) arr = _moveaxis(arr, axis, 0)
def scan_fun(carry, _): def scan_fun(carry, _):
i, x = carry i, x = carry
y = _where(i == 0, arr[0].astype(dtype), self._call(x.astype(dtype), arr[i].astype(dtype))) y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype)))
return (i + 1, y), y return (i + 1, y), y
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
return _moveaxis(result, 0, axis) return _moveaxis(result, 0, axis)
@implements(np.ufunc.at, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) @partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
inplace: bool = True) -> Array: inplace: bool = True) -> Array:
"""Update elements of an array via the specified unary or binary ufunc.
JAX implementation of :meth:`numpy.ufunc.at`.
Note:
:meth:`numpy.ufunc.at` mutates arrays in-place. JAX arrays are immutable,
so :meth:`jax.numpy.ufunc.at` cannot replicate these semantics. Instead, JAX
will return the updated value, but requires explicitly passing ``inplace=False``
as a reminder of this difference.
Args:
a: N-dimensional array to update
indices: index, slice, or tuple of indices and slices.
b: array of values for binary ufunc updates.
inplace: must be set to False to indicate that an updated copy will be returned.
Returns:
an updated copy of the input array.
Examples:
Add numbers to specified indices:
>>> x = jnp.ones(10, dtype=int)
>>> indices = jnp.array([2, 5, 7])
>>> values = jnp.array([10, 20, 30])
>>> jnp.add.at(x, indices, values, inplace=False)
Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
This is roughly equivalent to JAX's :meth:`jax.numpy.ndarray.at` method
called this way:
>>> x.at[indices].add(values)
Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
"""
if inplace: if inplace:
raise NotImplementedError(_AT_INPLACE_WARNING) raise NotImplementedError(_AT_INPLACE_WARNING)
if b is None:
return self._at_via_scan(a, indices) at = self.__static_props['at'] or self._at_via_scan
else: return at(a, indices) if b is None else at(a, indices, b)
return self._at_via_scan(a, indices, b)
def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
assert len(args) in {0, 1} assert len(args) in {0, 1}
@ -276,14 +440,14 @@ class ufunc:
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
a = lax_internal.asarray(a).astype(dtype) a = lax_internal.asarray(a).astype(dtype)
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
indices = _eliminate_deprecated_list_indexing(indices) indices = jnp._eliminate_deprecated_list_indexing(indices)
if not indices: if not indices:
return a return a
shapes = [np.shape(i) for i in indices if not isinstance(i, slice)] shapes = [np.shape(i) for i in indices if not isinstance(i, slice)]
shape = shapes and jax.lax.broadcast_shapes(*shapes) shape = shapes and jax.lax.broadcast_shapes(*shapes)
if not shape: if not shape:
return a.at[indices].set(self._call(a.at[indices].get(), *args)) return a.at[indices].set(self(a.at[indices].get(), *args))
if args: if args:
arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):])) arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):]))
@ -293,28 +457,65 @@ class ufunc:
def scan_fun(carry, x): def scan_fun(carry, x):
i, a = carry i, a = carry
idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices) idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
a = a.at[idx].set(self._call(a.at[idx].get(), *(arg[i] for arg in args))) a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
return (i + 1, a), x return (i + 1, a), x
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
return carry[1] return carry[1]
@implements(np.ufunc.reduceat, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0,
dtype: DTypeLike | None = None, out: None = None) -> Array: dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Reduce an array between specified indices via a binary ufunc.
JAX implementation of :meth:`numpy.ufunc.reduceat`.
Args:
a: N-dimensional array to reduce
indices: a 1-dimensional array of increasing integer values which encodes
segments of the array to be reduced.
axis: integer specifying the axis along which to reduce: default=0.
dtype: optionally specify the dtype of the output array.
out: unused by JAX
Returns:
An array containing the reduced values.
Examples:
The ``reduceat`` method lets you efficiently compute reduction operations
over array segments. For example:
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
>>> indices = jnp.array([0, 2, 5])
>>> jnp.add.reduceat(x, indices)
Array([ 3, 12, 21], dtype=int32)
This is more-or-less equivalent to the following:
>>> jnp.array([x[0:2].sum(), x[2:5].sum(), x[5:].sum()])
Array([ 3, 12, 21], dtype=int32)
For some binary ufuncs, JAX provides similar APIs within :mod:`jax.ops`.
For example, :meth:`jax.numpy.add.reduceat` is similar to :func:`jax.ops.segment_sum`,
although in this case the segments are defined via an array of segment ids:
>>> segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 2])
>>> jax.ops.segment_sum(x, segments)
Array([ 3, 12, 21], dtype=int32)
"""
if self.nin != 2: if self.nin != 2:
raise ValueError("reduceat only supported for binary ufuncs") raise ValueError("reduceat only supported for binary ufuncs")
if self.nout != 1: if self.nout != 1:
raise ValueError("reduceat only supported for functions returning a single value") raise ValueError("reduceat only supported for functions returning a single value")
if out is not None: if out is not None:
raise NotImplementedError(f"out argument of {self.__name__}.reduceat()") raise NotImplementedError(f"out argument of {self.__name__}.reduceat()")
return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype)
reduceat = self.__static_props['reduceat'] or self._reduceat_via_scan
return reduceat(a, indices, axis=axis, dtype=dtype)
def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0,
dtype: DTypeLike | None = None) -> Array: dtype: DTypeLike | None = None) -> Array:
check_arraylike(f"{self.__name__}.reduceat", a, indices) check_arraylike(f"{self.__name__}.reduceat", a, indices)
a = lax_internal.asarray(a) a = lax_internal.asarray(a)
idx_tuple = _eliminate_deprecated_list_indexing(indices) idx_tuple = jnp._eliminate_deprecated_list_indexing(indices)
assert len(idx_tuple) == 1 assert len(idx_tuple) == 1
indices = idx_tuple[0] indices = idx_tuple[0]
if a.ndim == 0: if a.ndim == 0:
@ -326,27 +527,62 @@ class ufunc:
if axis is None or isinstance(axis, (tuple, list)): if axis is None or isinstance(axis, (tuple, list)):
raise ValueError("reduceat requires a single integer axis.") raise ValueError("reduceat requires a single integer axis.")
axis = canonicalize_axis(axis, a.ndim) axis = canonicalize_axis(axis, a.ndim)
out = take(a, indices, axis=axis) out = jnp.take(a, indices, axis=axis)
ind = jax.lax.expand_dims(append(indices, a.shape[axis]), ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]),
list(np.delete(np.arange(out.ndim), axis))) list(np.delete(np.arange(out.ndim), axis)))
ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis)
ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
def loop_body(i, out): def loop_body(i, out):
return _where((i > ind_start) & (i < ind_end), return _where((i > ind_start) & (i < ind_end),
self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
out) out)
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
@implements(np.ufunc.outer, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0]) @partial(jax.jit, static_argnums=[0])
def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array: def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array:
"""Apply the function to all pairs of values in ``A`` and ``B``.
JAX implementation of :meth:`numpy.ufunc.outer`.
Args:
A: N-dimensional array
B: N-dimensional array
Returns:
An array of shape ``(*A.shape, *B.shape)``
Examples:
A times-table for integers 1...10 created via
:meth:`jax.numpy.multiply.outer`:
>>> x = jnp.arange(1, 11)
>>> print(jnp.multiply.outer(x, x))
[[ 1 2 3 4 5 6 7 8 9 10]
[ 2 4 6 8 10 12 14 16 18 20]
[ 3 6 9 12 15 18 21 24 27 30]
[ 4 8 12 16 20 24 28 32 36 40]
[ 5 10 15 20 25 30 35 40 45 50]
[ 6 12 18 24 30 36 42 48 54 60]
[ 7 14 21 28 35 42 49 56 63 70]
[ 8 16 24 32 40 48 56 64 72 80]
[ 9 18 27 36 45 54 63 72 81 90]
[ 10 20 30 40 50 60 70 80 90 100]]
For input arrays with ``N`` and ``M`` dimensions respectively, the output
will have dimension ``N + M``:
>>> x = jnp.ones((1, 3, 5))
>>> y = jnp.ones((2, 4))
>>> jnp.add.outer(x, y).shape
(1, 3, 5, 2, 4)
"""
if self.nin != 2: if self.nin != 2:
raise ValueError("outer only supported for binary ufuncs") raise ValueError("outer only supported for binary ufuncs")
if self.nout != 1: if self.nout != 1:
raise ValueError("outer only supported for functions returning a single value") raise ValueError("outer only supported for functions returning a single value")
check_arraylike(f"{self.__name__}.outer", A, B) check_arraylike(f"{self.__name__}.outer", A, B)
_ravel = lambda A: jax.lax.reshape(A, (np.size(A),)) _ravel = lambda A: jax.lax.reshape(A, (np.size(A),))
result = jax.vmap(jax.vmap(partial(self._call, **kwargs), (None, 0)), (0, None))(_ravel(A), _ravel(B)) result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B))
return result.reshape(*np.shape(A), *np.shape(B)) return result.reshape(*np.shape(A), *np.shape(B))
@ -363,4 +599,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
Returns: Returns:
wrapped : jax.numpy.ufunc wrapper of func. wrapped : jax.numpy.ufunc wrapper of func.
""" """
return ufunc(func, nin, nout, identity=identity, update_doc=True) return ufunc(func, nin, nout, identity=identity)
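For context, a minimal sketch of :func:`jax.numpy.frompyfunc` as updated here (the wrapped lambda is illustrative, not part of this commit):

import jax.numpy as jnp

# Wrap a custom binary operation as a jnp.ufunc; identity=0 makes
# reduce() well-defined for this op.
hypot = jnp.frompyfunc(lambda a, b: (a**2 + b**2) ** 0.5,
                       nin=2, nout=1, identity=0)
hypot(jnp.float32(3.0), jnp.float32(4.0))   # Array(5., dtype=float32)
hypot.outer(jnp.ones(3), jnp.ones(4)).shape  # (3, 4)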

View File

@ -30,11 +30,13 @@ from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax from jax._src.lax import lax
from jax._src.lax import other as lax_other from jax._src.lax import other as lax_other
from jax._src.typing import Array, ArrayLike from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.numpy.util import ( from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact, check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, implements, check_no_float0s) promote_shapes, _where, implements, check_no_float0s)
from jax._src.numpy.ufunc_api import ufunc
from jax._src.numpy import reductions
_lax_const = lax._const _lax_const = lax._const
@ -298,31 +300,81 @@ def sqrt(x: ArrayLike, /) -> Array:
def cbrt(x: ArrayLike, /) -> Array: def cbrt(x: ArrayLike, /) -> Array:
return lax.cbrt(*promote_args_inexact('cbrt', x)) return lax.cbrt(*promote_args_inexact('cbrt', x))
@implements(np.add, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def add(x: ArrayLike, y: ArrayLike, /) -> Array: def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Add two arrays element-wise.
JAX implementation of :obj:`numpy.add`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: arrays to add. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise addition.
"""
x, y = promote_args("add", x, y) x, y = promote_args("add", x, y)
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
@implements(np.multiply, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Multiply two arrays element-wise.
JAX implementation of :obj:`numpy.multiply`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: arrays to multiply. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise multiplication.
"""
x, y = promote_args("multiply", x, y) x, y = promote_args("multiply", x, y)
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
@implements(np.bitwise_and, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise AND operation elementwise.
JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise AND.
"""
return lax.bitwise_and(*promote_args("bitwise_and", x, y)) return lax.bitwise_and(*promote_args("bitwise_and", x, y))
@implements(np.bitwise_or, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise OR operation elementwise.
JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise OR.
"""
return lax.bitwise_or(*promote_args("bitwise_or", x, y)) return lax.bitwise_or(*promote_args("bitwise_or", x, y))
@implements(np.bitwise_xor, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise XOR operation elementwise.
JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise XOR.
"""
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))
@implements(np.left_shift, module='numpy') @implements(np.left_shift, module='numpy')
@ -376,19 +428,49 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.nextafter(*promote_args_inexact("nextafter", x, y)) return lax.nextafter(*promote_args_inexact("nextafter", x, y))
# Logical ops # Logical ops
@implements(np.logical_and, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the logical AND operation elementwise.
JAX implementation of :obj:`numpy.logical_and`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: input arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise logical AND.
"""
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))
@implements(np.logical_or, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the logical OR operation elementwise.
JAX implementation of :obj:`numpy.logical_or`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: input arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise logical OR.
"""
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))
@implements(np.logical_xor, module='numpy')
@partial(jit, inline=True) @partial(jit, inline=True)
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the logical XOR operation elementwise.
JAX implementation of :obj:`numpy.logical_xor`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: input arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise logical XOR.
"""
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))
@implements(np.logical_not, module='numpy') @implements(np.logical_not, module='numpy')
@ -1281,3 +1363,38 @@ def _sinc_maclaurin(k, x):
def _sinc_maclaurin_jvp(k, primals, tangents): def _sinc_maclaurin_jvp(k, primals, tangents):
(x,), (t,) = primals, tangents (x,), (t,) = primals, tangents
return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t
def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None):
if initial is not None:
raise ValueError("initial argument not supported in jnp.logical_and.reduce()")
result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where)
return result if dtype is None else result.astype(dtype)
def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None):
if initial is not None:
raise ValueError("initial argument not supported in jnp.logical_or.reduce()")
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
return result if dtype is None else result.astype(dtype)
# Generate ufunc interfaces for several common binary functions.
# We start with binary ufuncs that have well-defined identities.
# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience?
# TODO(jakevdp): optimize some implementations.
# - define add.at/multiply.at in terms of scatter_add/scatter_mul
# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod
# - define all monoidal reductions in terms of lax.reduce
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum)
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod)
bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and)
bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or)
bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor)
logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce)
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)
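A hedged sanity check of the dispatch table above; each method should agree with the reduction it is wired to:

import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])
# add.reduce is wired to reductions.sum rather than the generic scan fallback:
assert (jnp.add.reduce(x, axis=0) == jnp.sum(x, axis=0)).all()
# multiply.accumulate is wired to reductions.cumprod:
assert (jnp.multiply.accumulate(x, axis=1) == jnp.cumprod(x, axis=1)).all()
# logical_or.reduce routes through _logical_or_reduce, i.e. jnp.any:
assert (jnp.logical_or.reduce(x > 3, axis=0) == jnp.any(x > 3, axis=0)).all()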

View File

@ -965,8 +965,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
else: else:
graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def:
self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def)
def test_bfloat16_constant(self): def test_bfloat16_constant(self):
# Re: https://github.com/google/jax/issues/3942 # Re: https://github.com/google/jax/issues/3942

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import builtins import builtins
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, Literal, NamedTuple, TypeVar, Union, overload from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload
from jax._src import core as _core from jax._src import core as _core
from jax._src import dtypes as _dtypes from jax._src import dtypes as _dtypes
@ -28,6 +28,34 @@ _Device = Device
ComplexWarning: type ComplexWarning: type
class BinaryUfunc(Protocol):
@property
def nin(self) -> int: ...
@property
def nout(self) -> int: ...
@property
def nargs(self) -> int: ...
@property
def identity(self) -> builtins.bool | int | float: ...
def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ...
def reduce(self, arr: ArrayLike, /, *,
axis: int | None = 0,
dtype: DTypeLike | None = None,
keepdims: builtins.bool = False,
initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array: ...
def accumulate(self, a: ArrayLike, /, *,
axis: int = 0,
dtype: DTypeLike | None = None,
out: None = None) -> Array: ...
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
inplace: builtins.bool = True) -> Array: ...
def reduceat(self, a: ArrayLike, indices: Any, *,
axis: int = 0,
dtype: DTypeLike | None = None,
out: None = None) -> Array: ...
def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ...
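Since ``BinaryUfunc`` lives only in this type stub, here is a hedged runtime analogue of what the protocol lets a type checker accept (the names ``SupportsReduce`` and ``fold`` are hypothetical):

from typing import Protocol

import jax.numpy as jnp
from jax import Array

class SupportsReduce(Protocol):
    # Structural stand-in for the stub's BinaryUfunc, reduced to one method.
    def reduce(self, a: Array, /, *, axis: int = 0) -> Array: ...

def fold(op: SupportsReduce, x: Array) -> Array:
    # Any jnp binary ufunc (jnp.add, jnp.bitwise_or, ...) matches structurally.
    return op.reduce(x, axis=0)

fold(jnp.add, jnp.arange(4))         # Array(6, dtype=int32)
fold(jnp.bitwise_or, jnp.arange(4))  # Array(3, dtype=int32)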
__array_api_version__: str __array_api_version__: str
def __array_namespace_info__() -> ArrayNamespaceInfo: ... def __array_namespace_info__() -> ArrayNamespaceInfo: ...
@ -36,7 +64,7 @@ def abs(x: ArrayLike, /) -> Array: ...
def absolute(x: ArrayLike, /) -> Array: ... def absolute(x: ArrayLike, /) -> Array: ...
def acos(x: ArrayLike, /) -> Array: ... def acos(x: ArrayLike, /) -> Array: ...
def acosh(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ...
def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... add: BinaryUfunc
def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., def amax(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
where: ArrayLike | None = ...) -> Array: ... where: ArrayLike | None = ...) -> Array: ...
@ -162,14 +190,14 @@ def bartlett(M: int) -> Array: ...
bfloat16: Any bfloat16: Any
def bincount(x: ArrayLike, weights: ArrayLike | None = ..., def bincount(x: ArrayLike, weights: ArrayLike | None = ...,
minlength: int = ..., *, length: int | None = ...) -> Array: ... minlength: int = ..., *, length: int | None = ...) -> Array: ...
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... bitwise_and: BinaryUfunc
def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_count(x: ArrayLike, /) -> Array: ...
def bitwise_invert(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ...
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_not(x: ArrayLike, /) -> Array: ... def bitwise_not(x: ArrayLike, /) -> Array: ...
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... bitwise_or: BinaryUfunc
def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... bitwise_xor: BinaryUfunc
def blackman(M: int) -> Array: ... def blackman(M: int) -> Array: ...
def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ...
bool: Any bool: Any
@ -251,7 +279,7 @@ def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ...) -> Array: ... out: None = ...) -> Array: ...
def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ...,
dtype: DTypeLike | None = ..., dtype: DTypeLike | None = ...,
include_initial: bool = ...) -> Array: ... include_initial: builtins.bool = ...) -> Array: ...
def deg2rad(x: ArrayLike, /) -> Array: ... def deg2rad(x: ArrayLike, /) -> Array: ...
degrees = rad2deg degrees = rad2deg
@ -557,10 +585,10 @@ def log1p(x: ArrayLike, /) -> Array: ...
def log2(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ...
def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... logical_and: BinaryUfunc
def logical_not(x: ArrayLike, /) -> Array: ... def logical_not(x: ArrayLike, /) -> Array: ...
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... logical_or: BinaryUfunc
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... logical_xor: BinaryUfunc
def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., def logspace(start: ArrayLike, stop: ArrayLike, num: int = ...,
endpoint: builtins.bool = ..., base: ArrayLike = ..., endpoint: builtins.bool = ..., base: ArrayLike = ...,
dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ...
@ -588,7 +616,7 @@ def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ...
def moveaxis(a: ArrayLike, source: int | Sequence[int], def moveaxis(a: ArrayLike, source: int | Sequence[int],
destination: int | Sequence[int]) -> Array: ... destination: int | Sequence[int]) -> Array: ...
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... multiply: BinaryUfunc
nan: float nan: float
def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ...,
posinf: ArrayLike | None = ..., posinf: ArrayLike | None = ...,

View File

@ -14,6 +14,7 @@
"""Tests for jax.numpy.ufunc and its methods.""" """Tests for jax.numpy.ufunc and its methods."""
import itertools
from functools import partial from functools import partial
from absl.testing import absltest from absl.testing import absltest
@ -22,7 +23,6 @@ import numpy as np
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.numpy.ufunc_api import get_if_single_primitive
jax.config.parse_flags_with_absl() jax.config.parse_flags_with_absl()
@ -54,18 +54,21 @@ SCALAR_FUNCS = [
{'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None}, {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
] ]
FASTPATH_FUNCS = [ def _jnp_ufunc_props(name):
{'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0, jnp_func = getattr(jnp, name)
'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p}, assert isinstance(jnp_func, jnp.ufunc)
{'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1, np_func = getattr(np, name)
'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p}, dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types]
return [dict(name=name, dtype=dtype) for dtype in dtypes]
JAX_NUMPY_UFUNCS = [
name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc)
] ]
NON_FASTPATH_FUNCS = [ JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
{'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0}, _jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS
{'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1}, ))
{'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1},
]
broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
nonscalar_shapes = [(3,), (4,), (4, 3)] nonscalar_shapes = [(3,), (4,), (4, 3)]
@ -80,23 +83,40 @@ def cast_outputs(fun):
class LaxNumpyUfuncTests(jtu.JaxTestCase): class LaxNumpyUfuncTests(jtu.JaxTestCase):
@jtu.sample_product(SCALAR_FUNCS) @jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_properties(self, func, nin, nout, identity): def test_frompyfunc_properties(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun.identity, identity) self.assertEqual(jnp_fun.identity, identity)
self.assertEqual(jnp_fun.nin, nin) self.assertEqual(jnp_fun.nin, nin)
self.assertEqual(jnp_fun.nout, nout) self.assertEqual(jnp_fun.nout, nout)
self.assertEqual(jnp_fun.nargs, nin) self.assertEqual(jnp_fun.nargs, nin)
@jtu.sample_product(name=JAX_NUMPY_UFUNCS)
def test_ufunc_properties(self, name):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
self.assertEqual(jnp_fun.identity, np_fun.identity)
self.assertEqual(jnp_fun.nin, np_fun.nin)
self.assertEqual(jnp_fun.nout, np_fun.nout)
self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 because NumPy accepts `out`
@jtu.sample_product(SCALAR_FUNCS) @jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_properties_readonly(self, func, nin, nout, identity): def test_frompyfunc_properties_readonly(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']: for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
getattr(jnp_fun, attr) # no error on attribute access.
with self.assertRaises(AttributeError):
setattr(jnp_fun, attr, None) # error when trying to mutate.
@jtu.sample_product(name=JAX_NUMPY_UFUNCS)
def test_ufunc_properties_readonly(self, name):
jnp_fun = getattr(jnp, name)
for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
getattr(jnp_fun, attr) # no error on attribute access. getattr(jnp_fun, attr) # no error on attribute access.
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
setattr(jnp_fun, attr, None) # error when trying to mutate. setattr(jnp_fun, attr, None) # error when trying to mutate.
@jtu.sample_product(SCALAR_FUNCS) @jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_hash(self, func, nin, nout, identity): def test_frompyfunc_hash(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun, jnp_fun_2) self.assertEqual(jnp_fun, jnp_fun_2)
@ -113,7 +133,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
dtype=jtu.dtypes.floating, dtype=jtu.dtypes.floating,
) )
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity)) np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity))
@ -123,13 +143,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product( @jtu.sample_product(
SCALAR_FUNCS, SCALAR_FUNCS,
lhs_shape=broadcast_compatible_shapes, lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes, rhs_shape=broadcast_compatible_shapes,
dtype=jtu.dtypes.floating, dtype=jtu.dtypes.floating,
) )
def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
if (nin, nout) != (2, 1): if (nin, nout) != (2, 1):
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}") self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer
@ -141,6 +176,23 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker)
self._CompileAndCheck(jnp_fun.outer, args_maker)
@jtu.sample_product( @jtu.sample_product(
SCALAR_FUNCS, SCALAR_FUNCS,
[{'shape': shape, 'axis': axis} [{'shape': shape, 'axis': axis}
@ -148,7 +200,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
for axis in [None, *range(-len(shape), len(shape))]], for axis in [None, *range(-len(shape), len(shape))]],
dtype=jtu.dtypes.floating, dtype=jtu.dtypes.floating,
) )
def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1): if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
@@ -160,6 +212,26 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+      JAX_NUMPY_UFUNCS_WITH_DTYPES,
+      [{'shape': shape, 'axis': axis}
+       for shape in nonscalar_shapes
+       for axis in [None, *range(-len(shape), len(shape))]],
+  )
+  def test_ufunc_reduce(self, name, shape, axis, dtype):
+    jnp_fun = getattr(jnp, name)
+    np_fun = getattr(np, name)
+    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
+      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
+    jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
+    np_fun_reduce = partial(np_fun.reduce, axis=axis)
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+    self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
+    self._CompileAndCheck(jnp_fun_reduce, args_maker)
+
   @jtu.sample_product(
       SCALAR_FUNCS,
       [{'shape': shape, 'axis': axis}
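test_ufunc_reduce exercises ufunc.reduce, which folds the binary op along one axis, or over all elements when axis=None. For an op with an identity this matches the familiar aggregate, e.g. add.reduce behaves like sum; a sketch:

    import jax.numpy as jnp

    x = jnp.arange(6.0).reshape(2, 3)
    assert bool(jnp.all(jnp.add.reduce(x, axis=0) == x.sum(axis=0)))
    assert bool(jnp.add.reduce(x, axis=None) == x.sum())  # full reduction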
@@ -167,7 +239,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
        for axis in [None, *range(-len(shape), len(shape))]],
       dtype=jtu.dtypes.floating,
   )
-  def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
+  def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
     if (nin, nout) != (2, 1):
       self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
@@ -194,42 +266,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     self._CompileAndCheck(jnp_fun, args_maker)
 
   @jtu.sample_product(
-      FASTPATH_FUNCS,
+      JAX_NUMPY_UFUNCS_WITH_DTYPES,
       [{'shape': shape, 'axis': axis}
        for shape in nonscalar_shapes
-       for axis in range(-len(shape), len(shape))],
-      dtype=jtu.dtypes.floating,
+       for axis in [None, *range(-len(shape), len(shape))]],
   )
-  def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
-    del accumulator  # unused
-    if (nin, nout) != (2, 1):
-      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
-    rng = jtu.rand_default(self.rng())
-    args = (rng(shape, dtype),)
-    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
-    self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer)
-
-  @jtu.sample_product(
-      NON_FASTPATH_FUNCS,
-      [{'shape': shape, 'axis': axis}
-       for shape in nonscalar_shapes
-       for axis in range(-len(shape), len(shape))],
-      dtype=jtu.dtypes.floating,
-  )
-  def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype):
-    if (nin, nout) != (2, 1):
-      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
-    rng = jtu.rand_default(self.rng())
-    args = (rng(shape, dtype),)
-    _ = func(0, 0)  # function should not error.
-    reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
-    self.assertIsNone(get_if_single_primitive(reduce_fun, *args))
-    accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
-    self.assertIsNone(get_if_single_primitive(accum_fun, *args))
+  def test_ufunc_reduce_where(self, name, shape, axis, dtype):
+    jnp_fun = getattr(jnp, name)
+    np_fun = getattr(np, name)
+    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
+      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
+    if jnp_fun.identity is None:
+      self.skipTest("reduce with where requires identity")
+    jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where)
+    np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where)
+    rng = jtu.rand_default(self.rng())
+    rng_where = jtu.rand_bool(self.rng())
+    args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
+    self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
+    self._CompileAndCheck(jnp_fun_reduce, args_maker)
 
   @jtu.sample_product(
       SCALAR_FUNCS,
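The where= path in test_ufunc_reduce_where only makes sense when the ufunc has an identity to stand in for the masked-out elements, hence the identity skip above. A sketch with jnp.add, whose identity is 0:

    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    mask = jnp.array([True, False, True, False])
    r = jnp.add.reduce(x, where=mask)  # masked slots contribute the identity, 0
    assert float(r) == 4.0             # 1.0 + 3.0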
@@ -238,7 +296,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
        for axis in range(-len(shape), len(shape))],
       dtype=jtu.dtypes.floating,
   )
-  def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
+  def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
     if (nin, nout) != (2, 1):
       self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
     jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
@@ -251,20 +309,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     self._CompileAndCheck(jnp_fun, args_maker)
 
   @jtu.sample_product(
-      FASTPATH_FUNCS,
+      JAX_NUMPY_UFUNCS_WITH_DTYPES,
       [{'shape': shape, 'axis': axis}
        for shape in nonscalar_shapes
-       for axis in range(-len(shape), len(shape))],
-      dtype=jtu.dtypes.floating,
+       for axis in range(-len(shape), len(shape))]
   )
-  def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
-    del reducer  # unused
-    if (nin, nout) != (2, 1):
-      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
-    rng = jtu.rand_default(self.rng())
-    args = (rng(shape, dtype),)
-    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
-    self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator)
+  def test_ufunc_accumulate(self, name, shape, axis, dtype):
+    jnp_fun = getattr(jnp, name)
+    np_fun = getattr(np, name)
+    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
+      self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+    jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis)
+    def np_fun_accumulate(x):
+      # numpy accumulate has different dtype casting behavior.
+      result = np_fun.accumulate(x, axis=axis)
+      return result if x.dtype == bool else result.astype(x.dtype)
+    self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker)
+    self._CompileAndCheck(jnp_fun_accumulate, args_maker)
 
   @jtu.sample_product(
       SCALAR_FUNCS,
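The np_fun_accumulate wrapper above exists because NumPy promotes small integer dtypes when accumulating, while the JAX version keeps the input dtype. A sketch of the discrepancy (the promoted dtype is NumPy's platform-default integer, so the printed value may vary):

    import numpy as np
    import jax.numpy as jnp

    x8 = np.arange(5, dtype=np.int8)
    print(np.add.accumulate(x8).dtype)                # typically int64
    print(jnp.add.accumulate(jnp.asarray(x8)).dtype)  # int8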
@@ -272,7 +338,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
       idx_shape=[(), (2,)],
       dtype=jtu.dtypes.floating,
   )
-  def test_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
+  def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
     if (nin, nout) != (2, 1):
       self.skipTest(f"at requires (nin, nout)=(2, 1); got {(nin, nout)=}")
     jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False)
@@ -288,7 +354,31 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
-  def test_at_broadcasting(self):
+  @jtu.sample_product(
+      JAX_NUMPY_UFUNCS_WITH_DTYPES,
+      shape=nonscalar_shapes,
+      idx_shape=[(), (2,)],
+  )
+  def test_ufunc_at(self, name, shape, idx_shape, dtype):
+    jnp_fun = getattr(jnp, name)
+    np_fun = getattr(np, name)
+    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
+      self.skipTest(f"at requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
+    rng = jtu.rand_default(self.rng())
+    idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
+    args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]
+    jnp_fun_at = partial(jnp_fun.at, inplace=False)
+    def np_fun_at(x, idx, y):
+      x_copy = x.copy()
+      np_fun.at(x_copy, idx, y)
+      return x_copy
+    self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker)
+    self._CompileAndCheck(jnp_fun_at, args_maker)
+
+  def test_frompyfunc_at_broadcasting(self):
     # Regression test for https://github.com/google/jax/issues/18004
     args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
                           np.arange(9.0).reshape(3, 3)]
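test_ufunc_at covers ufunc.at, which applies the op at the given indices, with repeated indices applied repeatedly, unlike x[idx] += y. Because JAX arrays are immutable, the JAX method takes inplace=False and returns a new array, while np.add.at mutates its first argument, hence the copying np_fun_at wrapper above. A sketch:

    import jax.numpy as jnp

    x = jnp.zeros(4)
    idx = jnp.array([0, 0, 2])
    out = jnp.add.at(x, idx, 1.0, inplace=False)
    # index 0 appears twice, so it is incremented twice: [2., 0., 1., 0.]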
@@ -309,7 +399,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
       idx_shape=[(0,), (3,), (5,)],
       dtype=jtu.dtypes.floating,
   )
-  def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
+  def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
     if (nin, nout) != (2, 1):
       self.skipTest(f"reduceat requires (nin, nout)=(2, 1); got {(nin, nout)=}")
     jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis)
@@ -322,6 +412,33 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+      JAX_NUMPY_UFUNCS_WITH_DTYPES,
+      [{'shape': shape, 'axis': axis}
+       for shape in nonscalar_shapes
+       for axis in [*range(-len(shape), len(shape))]],
+      idx_shape=[(0,), (3,), (5,)],
+  )
+  def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype):
+    jnp_fun = getattr(jnp, name)
+    np_fun = getattr(np, name)
+    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
+      self.skipTest(f"reduceat requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
+    if name in ['add', 'multiply'] and dtype == bool:
+      # TODO(jakevdp): figure out how to fix these cases.
+      self.skipTest(f"known failure for {name}.reduceat with {dtype=}")
+    rng = jtu.rand_default(self.rng())
+    idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis])
+    args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
+    def np_fun_reduceat(x, i):
+      # Numpy has different casting behavior.
+      return np_fun.reduceat(x, i).astype(x.dtype)
+    self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker)
+    self._CompileAndCheck(jnp_fun.reduceat, args_maker)
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
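For reference, ufunc.reduceat, as exercised by test_ufunc_reduceat above, reduces over index-delimited segments: result[i] reduces x[indices[i]:indices[i+1]], with the final segment running to the end of the axis. A sketch:

    import jax.numpy as jnp

    x = jnp.arange(8)
    idx = jnp.array([0, 4, 6])
    print(jnp.add.reduceat(x, idx))  # [0+1+2+3, 4+5, 6+7] -> [6 9 13]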