First pass at ufunc interfaces for several jax.numpy functions

This commit is contained in:
Jake VanderPlas 2024-08-30 11:53:02 -07:00
parent db4be03f02
commit a3d6cf007e
7 changed files with 692 additions and 188 deletions

View File

@ -27,6 +27,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
and ``meta_fields`` includes all dataclass fields with ``init=True``
and only them, if ``nodetype`` is a dataclass.
* Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc`
interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`,
{obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`,
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
{obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`.
In future releases we plan to expand these to other ufuncs.
* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the

View File

@ -30,7 +30,6 @@ from jax._src import api
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
@ -2039,9 +2038,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
a_shape = a.shape
if squash_nans:
a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis)
counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
@ -2067,7 +2066,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
index[axis] = high
high_value = a[tuple(index)]
else:
a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
q = lax.mul(q, n - 1)
@ -2223,7 +2222,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
Array([1.5, 3. , 4.5], dtype=float32)
"""
check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, 100.0)
q, = promote_dtypes_inexact(q)
q = q / 100
if not isinstance(interpolation, DeprecatedArg):
deprecations.warn(
"jax-numpy-quantile-interpolation",

View File

@ -25,13 +25,11 @@ from typing import Any
import jax
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
import jax._src.numpy.lax_numpy as jnp
from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where
from jax._src.numpy.util import check_arraylike, _broadcast_to, _where
from jax._src.numpy.vectorize import vectorize
from jax._src.util import canonicalize_axis, set_module
from jax._src import pjit
import numpy as np
@ -42,81 +40,126 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
"""
def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None:
"""
If fun(*args) lowers to a single primitive with inputs and outputs matching
function inputs and outputs, return that primitive. Otherwise return None.
"""
try:
jaxpr = jax.make_jaxpr(fun)(*args)
except:
return None
while len(jaxpr.eqns) == 1:
eqn = jaxpr.eqns[0]
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
return None
elif (eqn.primitive == pjit.pjit_p and
all(pjit.is_unspecified(sharding) for sharding in
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
jaxpr = jaxpr.eqns[0].params['jaxpr']
else:
return jaxpr.eqns[0].primitive
return None
_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = {
lax_internal.add_p: reductions.sum,
lax_internal.mul_p: reductions.prod,
}
_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = {
lax_internal.add_p: reductions.cumsum,
lax_internal.mul_p: reductions.cumprod,
}
@set_module('jax.numpy')
class ufunc:
"""Functions that operate element-by-element on whole arrays.
"""Universal functions which operation element-by-element on arrays.
This is a class for LAX-backed implementations of numpy ufuncs.
JAX implementation of :class:`numpy.ufunc`.
This is a class for JAX-backed implementations of NumPy's ufunc APIs.
Most users will never need to instantiate :class:`ufunc`, but rather
will use the pre-defined ufuncs in :mod:`jax.numpy`.
For constructing your own ufuncs, see :func:`jax.numpy.frompyfunc`.
Examples:
Universal functions are functions that apply element-wise to broadcasted
arrays, but they also come with a number of extra attributes and methods.
As an example, consider the function :obj:`jax.numpy.add`. The object
acts as a function that applies addition to broadcasted arrays in an
element-wise manner:
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)
Each :class:`ufunc` object includes a number of attributes that describe
its behavior:
>>> jnp.add.nin # number of inputs
2
>>> jnp.add.nout # number of outputs
1
>>> jnp.add.identity # identity value, or None if no identity exists
0
Binary ufuncs like :obj:`jax.numpy.add` include number of methods to
apply the function to arrays in different manners.
The :meth:`~ufunc.outer` method applies the function to the
pair-wise outer-product of the input array values:
>>> jnp.add.outer(x, x)
Array([[ 2, 3, 4, 5, 6],
[ 3, 4, 5, 6, 7],
[ 4, 5, 6, 7, 8],
[ 5, 6, 7, 8, 9],
[ 6, 7, 8, 9, 10]], dtype=int32)
The :meth:`ufunc.reduce` method perfoms a reduction over the array.
For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``:
>>> jnp.add.reduce(x)
Array(15, dtype=int32)
The :meth:`ufunc.accumulate` method performs a cumulative reduction
over the array. For example, :meth:`jnp.add.accumulate` is equivalent
to :func:`jax.numpy.cumulative_sum`:
>>> jnp.add.accumulate(x)
Array([ 1, 3, 6, 10, 15], dtype=int32)
The :meth:`ufunc.at` method applies the function at particular indices in the
array; for ``jnp.add`` the computation is similar to :func:`jax.lax.scatter_add`:
>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101, 2, 3, 4, 5], dtype=int32)
And the :meth:`ufunc.reduceat` method performs a number of ``reduce``
operations bewteen specified indices of an array; for ``jnp.add`` the
operation is similar to :func:`jax.ops.segment_sum`:
>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)
In this case, the first element is ``x[0:2].sum()``, and the second element
is ``x[2:].sum()``.
"""
def __init__(self, func: Callable[..., Any], /,
nin: int, nout: int, *,
name: str | None = None,
nargs: int | None = None,
identity: Any = None, update_doc=False):
identity: Any = None,
call: Callable[..., Any] | None = None,
reduce: Callable[..., Any] | None = None,
accumulate: Callable[..., Any] | None = None,
at: Callable[..., Any] | None = None,
reduceat: Callable[..., Any] | None = None,
):
self.__doc__ = func.__doc__
self.__name__ = name or func.__name__
# We want ufunc instances to work properly when marked as static,
# and for this reason it's important that their properties not be
# mutated. We prevent this by storing them in a dunder attribute,
# and accessing them via read-only properties.
if update_doc:
self.__doc__ = func.__doc__
self.__name__ = name or func.__name__
self.__static_props = {
'func': func,
'call': vectorize(func),
'nin': operator.index(nin),
'nout': operator.index(nout),
'nargs': operator.index(nargs or nin),
'identity': identity
'identity': identity,
'call': call,
'reduce': reduce,
'accumulate': accumulate,
'at': at,
'reduceat': reduceat,
}
_func = property(lambda self: self.__static_props['func'])
_call = property(lambda self: self.__static_props['call'])
nin = property(lambda self: self.__static_props['nin'])
nout = property(lambda self: self.__static_props['nout'])
nargs = property(lambda self: self.__static_props['nargs'])
identity = property(lambda self: self.__static_props['identity'])
def __hash__(self) -> int:
# Do not include _call, because it is computed from _func.
# In both __hash__ and __eq__, we do not consider call, reduce, etc.
# because they are considered implementation details rather than
# necessary parts of object identity.
return hash((self._func, self.__name__, self.identity,
self.nin, self.nout, self.nargs))
def __eq__(self, other: Any) -> bool:
# Do not include _call, because it is computed from _func.
return isinstance(other, ufunc) and (
(self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) ==
(other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs))
@ -124,20 +167,71 @@ class ufunc:
def __repr__(self) -> str:
return f"<jnp.ufunc '{self.__name__}'>"
def __call__(self, *args: ArrayLike,
out: None = None, where: None = None,
**kwargs: Any) -> Any:
def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any:
check_arraylike(self.__name__, *args)
if out is not None:
raise NotImplementedError(f"out argument of {self}")
if where is not None:
raise NotImplementedError(f"where argument of {self}")
return self._call(*args, **kwargs)
call = self.__static_props['call'] or self._call_vectorized
return call(*args)
@partial(jax.jit, static_argnames=['self'])
def _call_vectorized(self, *args):
return vectorize(self._func)(*args)
@implements(np.ufunc.reduce, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
def reduce(self, a: ArrayLike, axis: int = 0,
dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
"""Reduction operation derived from a binary function.
JAX implementation of :meth:`numpy.ufunc.reduce`.
Args:
a: Input array.
axis: integer specifying the axis over which to reduce. default=0
dtype: optionally specify the type of the output array.
out: Unused by JAX
keepdims: If True, reduced axes are left in the result with size 1.
If False (default) then reduced axes are squeezed out.
initial: int or array, Default=None. Initial value for the reduction.
where: boolean mask, default=None. The elements to be used in the sum. Array
should be broadcast compatible to the input.
Returns:
array containing the result of the reduction operation.
Examples:
Consider the following array:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
:meth:`jax.numpy.add.reduce` is equivalent to :func:`jax.numpy.sum`
along ``axis=0``:
>>> jnp.add.reduce(x)
Array([5, 7, 9], dtype=int32)
>>> x.sum(0)
Array([5, 7, 9], dtype=int32)
Similarly, :meth:`jax.numpy.logical_and.reduce` is equivalent to
:func:`jax.numpy.all`:
>>> jnp.logical_and.reduce(x > 2)
Array([False, False, True], dtype=bool)
>>> jnp.all(x > 2, axis=0)
Array([False, False, True], dtype=bool)
Some reductions do not correspond to any built-in aggregation function;
for example here is the reduction of :func:`jax.numpy.bitwise_or` along
the first axis of ``x``:
>>> jnp.bitwise_or.reduce(x, axis=1)
Array([3, 7], dtype=int32)
"""
check_arraylike(f"{self.__name__}.reduce", a)
if self.nin != 2:
raise ValueError("reduce only supported for binary ufuncs")
@ -154,14 +248,10 @@ class ufunc:
"so to use a where mask one has to specify 'initial'.")
if lax_internal._dtype(where) != bool:
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
if primitive is None:
reducer = self._reduce_via_scan
else:
reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
reduce = self.__static_props['reduce'] or self._reduce_via_scan
return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
assert self.nin == 2 and self.nout == 1
@ -202,9 +292,9 @@ class ufunc:
def body_fun(i, val):
if where is None:
return self._call(val, arr[i].astype(dtype))
return self(val, arr[i].astype(dtype))
else:
return _where(where[i], self._call(val, arr[i].astype(dtype)), val)
return _where(where[i], self(val, arr[i].astype(dtype)), val)
start_value: ArrayLike
if initial is None:
@ -221,22 +311,63 @@ class ufunc:
result = result.reshape(final_shape)
return result
@implements(np.ufunc.accumulate, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None) -> Array:
"""Accumulate operation derived from binary ufunc.
JAX implementation of :func:`numpy.ufunc.accumulate`.
Args:
a: N-dimensional array over which to accumulate.
axis: integer axis over which accumulation will be performed (default = 0)
dtype: optionally specify the type of the output array.
out: Unused by JAX
Returns:
An array containing the accumulated result.
Examples:
Consider the following array:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
:meth:`jax.numpy.add.accumulate` is equivalent to
:func:`jax.numpy.cumsum` along the specified axis:
>>> jnp.add.accumulate(x, axis=1)
Array([[ 1, 3, 6],
[ 4, 9, 15]], dtype=int32)
>>> jnp.cumsum(x, axis=1)
Array([[ 1, 3, 6],
[ 4, 9, 15]], dtype=int32)
Similarly, :meth:`jax.numpy.multiply.accumulate` is equivalent to
:func:`jax.numpy.cumprod` along the specified axis:
>>> jnp.multiply.accumulate(x, axis=1)
Array([[ 1, 2, 6],
[ 4, 20, 120]], dtype=int32)
>>> jnp.cumprod(x, axis=1)
Array([[ 1, 2, 6],
[ 4, 20, 120]], dtype=int32)
For other binary ufuncs, the accumulation is an operation not available
via standard APIs. For example, :meth:`jax.numpy.bitwise_or.accumulate`
is essentially a bitwise cumulative ``any``:
>>> jnp.bitwise_or.accumulate(x, axis=1)
Array([[1, 3, 3],
[4, 5, 7]], dtype=int32)
"""
if self.nin != 2:
raise ValueError("accumulate only supported for binary ufuncs")
if self.nout != 1:
raise ValueError("accumulate only supported for functions returning a single value")
if out is not None:
raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
if primitive is None:
accumulator = self._accumulate_via_scan
else:
accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
return accumulator(a, axis=axis, dtype=dtype)
accumulate = self.__static_props['accumulate'] or self._accumulate_via_scan
return accumulate(a, axis=axis, dtype=dtype)
def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0,
dtype: DTypeLike | None = None) -> Array:
@ -254,21 +385,54 @@ class ufunc:
arr = _moveaxis(arr, axis, 0)
def scan_fun(carry, _):
i, x = carry
y = _where(i == 0, arr[0].astype(dtype), self._call(x.astype(dtype), arr[i].astype(dtype)))
y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype)))
return (i + 1, y), y
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
return _moveaxis(result, 0, axis)
@implements(np.ufunc.at, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
inplace: bool = True) -> Array:
"""Update elements of an array via the specified unary or binary ufunc.
JAX implementation of :func:`numpy.ufunc.at`.
Note:
:meth:`numpy.ufunc.at` mutates arrays in-place. JAX arrays are immutable,
so :meth:`jax.numpy.ufunc.at` cannot replicate these semantics. Instead, JAX
will return the updated value, but requires explicitly passing ``inplace=False``
as a reminder of this difference.
Args:
a: N-dimensional array to update
indices: index, slice, or tuple of indices and slices.
b: array of values for binary ufunc updates.
inplace: must be set to False to indicate that an updated copy will be returned.
Returns:
an updated copy of the input array.
Examples:
Add numbers to specified indices:
>>> x = jnp.ones(10, dtype=int)
>>> indices = jnp.array([2, 5, 7])
>>> values = jnp.array([10, 20, 30])
>>> jnp.add.at(x, indices, values, inplace=False)
Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
This is roughly equivalent to JAX's :meth:`jax.numpy.ndarray.at` method
called this way:
>>> x.at[indices].add(values)
Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
"""
if inplace:
raise NotImplementedError(_AT_INPLACE_WARNING)
if b is None:
return self._at_via_scan(a, indices)
else:
return self._at_via_scan(a, indices, b)
at = self.__static_props['at'] or self._at_via_scan
return at(a, indices) if b is None else at(a, indices, b)
def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
assert len(args) in {0, 1}
@ -276,14 +440,14 @@ class ufunc:
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
a = lax_internal.asarray(a).astype(dtype)
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
indices = _eliminate_deprecated_list_indexing(indices)
indices = jnp._eliminate_deprecated_list_indexing(indices)
if not indices:
return a
shapes = [np.shape(i) for i in indices if not isinstance(i, slice)]
shape = shapes and jax.lax.broadcast_shapes(*shapes)
if not shape:
return a.at[indices].set(self._call(a.at[indices].get(), *args))
return a.at[indices].set(self(a.at[indices].get(), *args))
if args:
arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):]))
@ -293,28 +457,65 @@ class ufunc:
def scan_fun(carry, x):
i, a = carry
idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
a = a.at[idx].set(self._call(a.at[idx].get(), *(arg[i] for arg in args)))
a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
return (i + 1, a), x
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
return carry[1]
@implements(np.ufunc.reduceat, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Reduce an array between specified indices via a binary ufunc.
JAX implementation of :meth:`numpy.ufunc.reduceat`
Args:
a: N-dimensional array to reduce
indices: a 1-dimensional array of increasing integer values which encodes
segments of the array to be reduced.
axis: integer specifying the axis along which to reduce: default=0.
dtype: optionally specify the dtype of the output array.
out: unused by JAX
Returns:
An array containing the reduced values.
Examples:
The ``reduce`` method lets you efficiently compute reduction operations
over array segments. For example:
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
>>> indices = jnp.array([0, 2, 5])
>>> jnp.add.reduce(x, indices)
Array([ 3, 12, 21], dtype=int32)
This is more-or-less equivalent to the following:
>>> jnp.array([x[0:2].sum(), x[2:5].sum(), x[5:].sum()])
Array([ 3, 12, 21], dtype=int32)
For some binary ufuncs, JAX provides similar APIs within :mod:`jax.ops`.
For example, :meth:`jax.add.reduceat` is similar to :func:`jax.ops.segment_sum`,
although in this case the segments are defined via an array of segment ids:
>>> segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 2])
>>> jax.ops.segment_sum(x, segments)
Array([ 3, 12, 21], dtype=int32)
"""
if self.nin != 2:
raise ValueError("reduceat only supported for binary ufuncs")
if self.nout != 1:
raise ValueError("reduceat only supported for functions returning a single value")
if out is not None:
raise NotImplementedError(f"out argument of {self.__name__}.reduceat()")
return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype)
reduceat = self.__static_props['reduceat'] or self._reduceat_via_scan
return reduceat(a, indices, axis=axis, dtype=dtype)
def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0,
dtype: DTypeLike | None = None) -> Array:
check_arraylike(f"{self.__name__}.reduceat", a, indices)
a = lax_internal.asarray(a)
idx_tuple = _eliminate_deprecated_list_indexing(indices)
idx_tuple = jnp._eliminate_deprecated_list_indexing(indices)
assert len(idx_tuple) == 1
indices = idx_tuple[0]
if a.ndim == 0:
@ -326,27 +527,62 @@ class ufunc:
if axis is None or isinstance(axis, (tuple, list)):
raise ValueError("reduceat requires a single integer axis.")
axis = canonicalize_axis(axis, a.ndim)
out = take(a, indices, axis=axis)
ind = jax.lax.expand_dims(append(indices, a.shape[axis]),
out = jnp.take(a, indices, axis=axis)
ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]),
list(np.delete(np.arange(out.ndim), axis)))
ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis)
ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
def loop_body(i, out):
return _where((i > ind_start) & (i < ind_end),
self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
out)
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
@implements(np.ufunc.outer, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0])
def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array:
def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array:
"""Apply the function to all pairs of values in ``A`` and ``B``.
JAX implementation of :meth:`numpy.ufunc.outer`.
Args:
A: N-dimensional array
B: N-dimensional array
Returns:
An array of shape `tuple(*A.shape, *B.shape)`
Examples:
A times-table for integers 1...10 created via
:meth:`jax.numpy.multiply.outer`:
>>> x = jnp.arange(1, 11)
>>> print(jnp.multiply.outer(x, x))
[[ 1 2 3 4 5 6 7 8 9 10]
[ 2 4 6 8 10 12 14 16 18 20]
[ 3 6 9 12 15 18 21 24 27 30]
[ 4 8 12 16 20 24 28 32 36 40]
[ 5 10 15 20 25 30 35 40 45 50]
[ 6 12 18 24 30 36 42 48 54 60]
[ 7 14 21 28 35 42 49 56 63 70]
[ 8 16 24 32 40 48 56 64 72 80]
[ 9 18 27 36 45 54 63 72 81 90]
[ 10 20 30 40 50 60 70 80 90 100]]
For input arrays with ``N`` and ``M`` dimensions respectively, the output
will have dimesion ``N + M``:
>>> x = jnp.ones((1, 3, 5))
>>> y = jnp.ones((2, 4))
>>> jnp.add.outer(x, y).shape
(1, 3, 5, 2, 4)
"""
if self.nin != 2:
raise ValueError("outer only supported for binary ufuncs")
if self.nout != 1:
raise ValueError("outer only supported for functions returning a single value")
check_arraylike(f"{self.__name__}.outer", A, B)
_ravel = lambda A: jax.lax.reshape(A, (np.size(A),))
result = jax.vmap(jax.vmap(partial(self._call, **kwargs), (None, 0)), (0, None))(_ravel(A), _ravel(B))
result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B))
return result.reshape(*np.shape(A), *np.shape(B))
@ -363,4 +599,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
Returns:
wrapped : jax.numpy.ufunc wrapper of func.
"""
return ufunc(func, nin, nout, identity=identity, update_doc=True)
return ufunc(func, nin, nout, identity=identity)

View File

@ -30,11 +30,13 @@ from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax
from jax._src.lax import other as lax_other
from jax._src.typing import Array, ArrayLike
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, implements, check_no_float0s)
from jax._src.numpy.ufunc_api import ufunc
from jax._src.numpy import reductions
_lax_const = lax._const
@ -298,31 +300,81 @@ def sqrt(x: ArrayLike, /) -> Array:
def cbrt(x: ArrayLike, /) -> Array:
return lax.cbrt(*promote_args_inexact('cbrt', x))
@implements(np.add, module='numpy')
@partial(jit, inline=True)
def add(x: ArrayLike, y: ArrayLike, /) -> Array:
def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Add two arrays element-wise.
JAX implementation of :obj:`numpy.add`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: arrays to add. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise addition.
"""
x, y = promote_args("add", x, y)
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
@implements(np.multiply, module='numpy')
@partial(jit, inline=True)
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Multiply two arrays element-wise.
JAX implementation of :obj:`numpy.multiply`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: arrays to multiply. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise multiplication.
"""
x, y = promote_args("multiply", x, y)
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
@implements(np.bitwise_and, module='numpy')
@partial(jit, inline=True)
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise AND operation elementwise.
JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise AND.
"""
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
@implements(np.bitwise_or, module='numpy')
@partial(jit, inline=True)
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise OR operation elementwise.
JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise OR.
"""
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
@implements(np.bitwise_xor, module='numpy')
@partial(jit, inline=True)
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise XOR operation elementwise.
JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise XOR.
"""
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))
@implements(np.left_shift, module='numpy')
@ -376,19 +428,49 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
# Logical ops
@implements(np.logical_and, module='numpy')
@partial(jit, inline=True)
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the logical AND operation elementwise.
JAX implementation of :obj:`numpy.logical_and`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: input arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise logical AND.
"""
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))
@implements(np.logical_or, module='numpy')
@partial(jit, inline=True)
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the logical OR operation elementwise.
JAX implementation of :obj:`numpy.logical_or`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: input arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise logical OR.
"""
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))
@implements(np.logical_xor, module='numpy')
@partial(jit, inline=True)
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the logical XOR operation elementwise.
JAX implementation of :obj:`numpy.logical_xor`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
Args:
x, y: input arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise logical XOR.
"""
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))
@implements(np.logical_not, module='numpy')
@ -1281,3 +1363,38 @@ def _sinc_maclaurin(k, x):
def _sinc_maclaurin_jvp(k, primals, tangents):
(x,), (t,) = primals, tangents
return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t
def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None):
if initial is not None:
raise ValueError("initial argument not supported in jnp.logical_and.reduce()")
result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where)
return result if dtype is None else result.astype(dtype)
def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None):
if initial is not None:
raise ValueError("initial argument not supported in jnp.logical_or.reduce()")
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
return result if dtype is None else result.astype(dtype)
# Generate ufunc interfaces for several common binary functions.
# We start with binary ufuncs that have well-defined identities.'
# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience?
# TODO(jakevdp): optimize some implementations.
# - define add.at/multiply.at in terms of scatter_add/scatter_mul
# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod
# - define all monoidal reductions in terms of lax.reduce
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum)
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod)
bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and)
bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or)
bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor)
logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce)
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)

View File

@ -965,8 +965,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
else:
graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def:
self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def)
if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def:
self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def)
def test_bfloat16_constant(self):
# Re: https://github.com/google/jax/issues/3942

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import builtins
from collections.abc import Callable, Sequence
from typing import Any, Literal, NamedTuple, TypeVar, Union, overload
from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload
from jax._src import core as _core
from jax._src import dtypes as _dtypes
@ -28,6 +28,34 @@ _Device = Device
ComplexWarning: type
class BinaryUfunc(Protocol):
@property
def nin(self) -> int: ...
@property
def nout(self) -> int: ...
@property
def nargs(self) -> int: ...
@property
def identity(self) -> builtins.bool | int | float: ...
def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ...
def reduce(self, arr: ArrayLike, /, *,
axis: int | None = 0,
dtype: DTypeLike | None = None,
keepdims: builtins.bool = False,
initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array: ...
def accumulate(self, a: ArrayLike, /, *,
axis: int = 0,
dtype: DTypeLike | None = None,
out: None = None) -> Array: ...
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
inplace: builtins.bool = True) -> Array: ...
def reduceat(self, a: ArrayLike, indices: Any, *,
axis: int = 0,
dtype: DTypeLike | None = None,
out: None = None) -> Array: ...
def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ...
__array_api_version__: str
def __array_namespace_info__() -> ArrayNamespaceInfo: ...
@ -36,7 +64,7 @@ def abs(x: ArrayLike, /) -> Array: ...
def absolute(x: ArrayLike, /) -> Array: ...
def acos(x: ArrayLike, /) -> Array: ...
def acosh(x: ArrayLike, /) -> Array: ...
def add(x: ArrayLike, y: ArrayLike, /) -> Array: ...
add: BinaryUfunc
def amax(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
where: ArrayLike | None = ...) -> Array: ...
@ -162,14 +190,14 @@ def bartlett(M: int) -> Array: ...
bfloat16: Any
def bincount(x: ArrayLike, weights: ArrayLike | None = ...,
minlength: int = ..., *, length: int | None = ...) -> Array: ...
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
bitwise_and: BinaryUfunc
def bitwise_count(x: ArrayLike, /) -> Array: ...
def bitwise_invert(x: ArrayLike, /) -> Array: ...
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_not(x: ArrayLike, /) -> Array: ...
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
bitwise_or: BinaryUfunc
def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
bitwise_xor: BinaryUfunc
def blackman(M: int) -> Array: ...
def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ...
bool: Any
@ -251,7 +279,7 @@ def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ...) -> Array: ...
def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ...,
dtype: DTypeLike | None = ...,
include_initial: bool = ...) -> Array: ...
include_initial: builtins.bool = ...) -> Array: ...
def deg2rad(x: ArrayLike, /) -> Array: ...
degrees = rad2deg
@ -557,10 +585,10 @@ def log1p(x: ArrayLike, /) -> Array: ...
def log2(x: ArrayLike, /) -> Array: ...
def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
logical_and: BinaryUfunc
def logical_not(x: ArrayLike, /) -> Array: ...
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
logical_or: BinaryUfunc
logical_xor: BinaryUfunc
def logspace(start: ArrayLike, stop: ArrayLike, num: int = ...,
endpoint: builtins.bool = ..., base: ArrayLike = ...,
dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ...
@ -588,7 +616,7 @@ def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ...
def moveaxis(a: ArrayLike, source: int | Sequence[int],
destination: int | Sequence[int]) -> Array: ...
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ...
multiply: BinaryUfunc
nan: float
def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ...,
posinf: ArrayLike | None = ...,

View File

@ -14,6 +14,7 @@
"""Tests for jax.numpy.ufunc and its methods."""
import itertools
from functools import partial
from absl.testing import absltest
@ -22,7 +23,6 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.numpy.ufunc_api import get_if_single_primitive
jax.config.parse_flags_with_absl()
@ -54,18 +54,21 @@ SCALAR_FUNCS = [
{'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
]
FASTPATH_FUNCS = [
{'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0,
'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p},
{'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1,
'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p},
def _jnp_ufunc_props(name):
jnp_func = getattr(jnp, name)
assert isinstance(jnp_func, jnp.ufunc)
np_func = getattr(np, name)
dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types]
return [dict(name=name, dtype=dtype) for dtype in dtypes]
JAX_NUMPY_UFUNCS = [
name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc)
]
NON_FASTPATH_FUNCS = [
{'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0},
{'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1},
{'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1},
]
JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
_jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS
))
broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
nonscalar_shapes = [(3,), (4,), (4, 3)]
@ -80,23 +83,40 @@ def cast_outputs(fun):
class LaxNumpyUfuncTests(jtu.JaxTestCase):
@jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_properties(self, func, nin, nout, identity):
def test_frompyfunc_properties(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun.identity, identity)
self.assertEqual(jnp_fun.nin, nin)
self.assertEqual(jnp_fun.nout, nout)
self.assertEqual(jnp_fun.nargs, nin)
@jtu.sample_product(name=JAX_NUMPY_UFUNCS)
def test_ufunc_properties(self, name):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
self.assertEqual(jnp_fun.identity, np_fun.identity)
self.assertEqual(jnp_fun.nin, np_fun.nin)
self.assertEqual(jnp_fun.nout, np_fun.nout)
self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 because NumPy accepts `out`
@jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_properties_readonly(self, func, nin, nout, identity):
def test_frompyfunc_properties_readonly(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']:
for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
getattr(jnp_fun, attr) # no error on attribute access.
with self.assertRaises(AttributeError):
setattr(jnp_fun, attr, None) # error when trying to mutate.
@jtu.sample_product(name=JAX_NUMPY_UFUNCS)
def test_ufunc_properties_readonly(self, name):
jnp_fun = getattr(jnp, name)
for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
getattr(jnp_fun, attr) # no error on attribute access.
with self.assertRaises(AttributeError):
setattr(jnp_fun, attr, None) # error when trying to mutate.
@jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_hash(self, func, nin, nout, identity):
def test_frompyfunc_hash(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun, jnp_fun_2)
@ -113,7 +133,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity))
@ -123,13 +143,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
SCALAR_FUNCS,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
dtype=jtu.dtypes.floating,
)
def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer
@ -141,6 +176,23 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker)
self._CompileAndCheck(jnp_fun.outer, args_maker)
@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
@ -148,7 +200,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
for axis in [None, *range(-len(shape), len(shape))]],
dtype=jtu.dtypes.floating,
)
def test_reduce(self, func, nin, nout, identity, shape, axis, dtype):
def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
@ -160,6 +212,26 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
)
def test_ufunc_reduce(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
np_fun_reduce = partial(np_fun.reduce, axis=axis)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
self._CompileAndCheck(jnp_fun_reduce, args_maker)
@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
@ -167,7 +239,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
for axis in [None, *range(-len(shape), len(shape))]],
dtype=jtu.dtypes.floating,
)
def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
@ -194,42 +266,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
FASTPATH_FUNCS,
JAX_NUMPY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
for axis in [None, *range(-len(shape), len(shape))]],
)
def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
del accumulator # unused
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
def test_ufunc_reduce_where(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
if jnp_fun.identity is None:
self.skipTest("reduce with where requires identity")
jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where)
np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where)
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer)
@jtu.sample_product(
NON_FASTPATH_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
_ = func(0, 0) # function should not error.
reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
self.assertIsNone(get_if_single_primitive(reduce_fun, *args))
accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
self.assertIsNone(get_if_single_primitive(accum_fun, *args))
rng_where = jtu.rand_bool(self.rng())
args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
self._CompileAndCheck(jnp_fun_reduce, args_maker)
@jtu.sample_product(
SCALAR_FUNCS,
@ -238,7 +296,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
@ -251,20 +309,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
FASTPATH_FUNCS,
JAX_NUMPY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
for axis in range(-len(shape), len(shape))]
)
def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
del reducer # unused
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
def test_ufunc_accumulate(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator)
args_maker = lambda: [rng(shape, dtype)]
jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis)
def np_fun_accumulate(x):
# numpy accumulate has different dtype casting behavior.
result = np_fun.accumulate(x, axis=axis)
return result if x.dtype == bool else result.astype(x.dtype)
self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker)
self._CompileAndCheck(jnp_fun_accumulate, args_maker)
@jtu.sample_product(
SCALAR_FUNCS,
@ -272,7 +338,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
idx_shape=[(), (2,)],
dtype=jtu.dtypes.floating,
)
def test_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False)
@ -288,7 +354,31 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def test_at_broadcasting(self):
@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
shape=nonscalar_shapes,
idx_shape=[(), (2,)],
)
def test_ufunc_at(self, name, shape, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]
jnp_fun_at = partial(jnp_fun.at, inplace=False)
def np_fun_at(x, idx, y):
x_copy = x.copy()
np_fun.at(x_copy, idx, y)
return x_copy
self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker)
self._CompileAndCheck(jnp_fun_at, args_maker)
def test_frompyfunc_at_broadcasting(self):
# Regression test for https://github.com/google/jax/issues/18004
args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
np.arange(9.0).reshape(3, 3)]
@ -309,7 +399,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
idx_shape=[(0,), (3,), (5,)],
dtype=jtu.dtypes.floating,
)
def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis)
@ -322,6 +412,33 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [*range(-len(shape), len(shape))]],
idx_shape=[(0,), (3,), (5,)],
)
def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
if name in ['add', 'multiply'] and dtype == bool:
# TODO(jakevdp): figure out how to fix thest cases.
self.skipTest(f"known failure for {name}.reduceat with {dtype=}")
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
def np_fun_reduceat(x, i):
# Numpy has different casting behavior.
return np_fun.reduceat(x, i).astype(x.dtype)
self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker)
self._CompileAndCheck(jnp_fun.reduceat, args_maker)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())