diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a3dac8c3..f59b07cd2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * ``jax.tree_util.register_dataclass`` now checks that ``data_fields`` and ``meta_fields`` includes all dataclass fields with ``init=True`` and only them, if ``nodetype`` is a dataclass. + * Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc` + interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`, + {obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`, + {obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`, + {obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`. + In future releases we plan to expand these to other ufuncs. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index e8815c943..dddb44dc9 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -30,7 +30,6 @@ from jax._src import api from jax._src import core from jax._src import deprecations from jax._src import dtypes -from jax._src.numpy import ufuncs from jax._src.numpy.util import ( _broadcast_to, check_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) @@ -2039,9 +2038,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a_shape = a.shape if squash_nans: - a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. + a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. a = lax.sort(a, dimension=axis) - counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) + counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) shape_after_reduction = counts.shape q = lax.expand_dims( q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) @@ -2067,7 +2066,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a) + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) q = lax.mul(q, n - 1) @@ -2223,7 +2222,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, Array([1.5, 3. , 4.5], dtype=float32) """ check_arraylike("nanpercentile", a, q) - q = ufuncs.true_divide(q, 100.0) + q, = promote_dtypes_inexact(q) + q = q / 100 if not isinstance(interpolation, DeprecatedArg): deprecations.warn( "jax-numpy-quantile-interpolation", diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 2e114193a..3473e8a74 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -25,13 +25,11 @@ from typing import Any import jax from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.lax import lax as lax_internal -from jax._src.numpy import reductions -from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take +import jax._src.numpy.lax_numpy as jnp from jax._src.numpy.reductions import _moveaxis -from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where +from jax._src.numpy.util import check_arraylike, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize from jax._src.util import canonicalize_axis, set_module -from jax._src import pjit import numpy as np @@ -42,81 +40,126 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. """ -def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None: - """ - If fun(*args) lowers to a single primitive with inputs and outputs matching - function inputs and outputs, return that primitive. Otherwise return None. - """ - try: - jaxpr = jax.make_jaxpr(fun)(*args) - except: - return None - while len(jaxpr.eqns) == 1: - eqn = jaxpr.eqns[0] - if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars): - return None - elif (eqn.primitive == pjit.pjit_p and - all(pjit.is_unspecified(sharding) for sharding in - (*eqn.params['in_shardings'], *eqn.params['out_shardings']))): - jaxpr = jaxpr.eqns[0].params['jaxpr'] - else: - return jaxpr.eqns[0].primitive - return None - - -_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.sum, - lax_internal.mul_p: reductions.prod, -} - - -_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.cumsum, - lax_internal.mul_p: reductions.cumprod, -} - - @set_module('jax.numpy') class ufunc: - """Functions that operate element-by-element on whole arrays. + """Universal functions which operation element-by-element on arrays. - This is a class for LAX-backed implementations of numpy ufuncs. + JAX implementation of :class:`numpy.ufunc`. + + This is a class for JAX-backed implementations of NumPy's ufunc APIs. + Most users will never need to instantiate :class:`ufunc`, but rather + will use the pre-defined ufuncs in :mod:`jax.numpy`. + + For constructing your own ufuncs, see :func:`jax.numpy.frompyfunc`. + + Examples: + Universal functions are functions that apply element-wise to broadcasted + arrays, but they also come with a number of extra attributes and methods. + + As an example, consider the function :obj:`jax.numpy.add`. The object + acts as a function that applies addition to broadcasted arrays in an + element-wise manner: + + >>> x = jnp.array([1, 2, 3, 4, 5]) + >>> jnp.add(x, 1) + Array([2, 3, 4, 5, 6], dtype=int32) + + Each :class:`ufunc` object includes a number of attributes that describe + its behavior: + + >>> jnp.add.nin # number of inputs + 2 + >>> jnp.add.nout # number of outputs + 1 + >>> jnp.add.identity # identity value, or None if no identity exists + 0 + + Binary ufuncs like :obj:`jax.numpy.add` include number of methods to + apply the function to arrays in different manners. + + The :meth:`~ufunc.outer` method applies the function to the + pair-wise outer-product of the input array values: + + >>> jnp.add.outer(x, x) + Array([[ 2, 3, 4, 5, 6], + [ 3, 4, 5, 6, 7], + [ 4, 5, 6, 7, 8], + [ 5, 6, 7, 8, 9], + [ 6, 7, 8, 9, 10]], dtype=int32) + + The :meth:`ufunc.reduce` method perfoms a reduction over the array. + For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``: + + >>> jnp.add.reduce(x) + Array(15, dtype=int32) + + The :meth:`ufunc.accumulate` method performs a cumulative reduction + over the array. For example, :meth:`jnp.add.accumulate` is equivalent + to :func:`jax.numpy.cumulative_sum`: + + >>> jnp.add.accumulate(x) + Array([ 1, 3, 6, 10, 15], dtype=int32) + + The :meth:`ufunc.at` method applies the function at particular indices in the + array; for ``jnp.add`` the computation is similar to :func:`jax.lax.scatter_add`: + + >>> jnp.add.at(x, 0, 100, inplace=False) + Array([101, 2, 3, 4, 5], dtype=int32) + + And the :meth:`ufunc.reduceat` method performs a number of ``reduce`` + operations bewteen specified indices of an array; for ``jnp.add`` the + operation is similar to :func:`jax.ops.segment_sum`: + + >>> jnp.add.reduceat(x, jnp.array([0, 2])) + Array([ 3, 12], dtype=int32) + + In this case, the first element is ``x[0:2].sum()``, and the second element + is ``x[2:].sum()``. """ def __init__(self, func: Callable[..., Any], /, nin: int, nout: int, *, name: str | None = None, nargs: int | None = None, - identity: Any = None, update_doc=False): + identity: Any = None, + call: Callable[..., Any] | None = None, + reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None, + ): + self.__doc__ = func.__doc__ + self.__name__ = name or func.__name__ # We want ufunc instances to work properly when marked as static, # and for this reason it's important that their properties not be # mutated. We prevent this by storing them in a dunder attribute, # and accessing them via read-only properties. - if update_doc: - self.__doc__ = func.__doc__ - self.__name__ = name or func.__name__ self.__static_props = { 'func': func, - 'call': vectorize(func), 'nin': operator.index(nin), 'nout': operator.index(nout), 'nargs': operator.index(nargs or nin), - 'identity': identity + 'identity': identity, + 'call': call, + 'reduce': reduce, + 'accumulate': accumulate, + 'at': at, + 'reduceat': reduceat, } _func = property(lambda self: self.__static_props['func']) - _call = property(lambda self: self.__static_props['call']) nin = property(lambda self: self.__static_props['nin']) nout = property(lambda self: self.__static_props['nout']) nargs = property(lambda self: self.__static_props['nargs']) identity = property(lambda self: self.__static_props['identity']) def __hash__(self) -> int: - # Do not include _call, because it is computed from _func. + # In both __hash__ and __eq__, we do not consider call, reduce, etc. + # because they are considered implementation details rather than + # necessary parts of object identity. return hash((self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs)) def __eq__(self, other: Any) -> bool: - # Do not include _call, because it is computed from _func. return isinstance(other, ufunc) and ( (self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) == (other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs)) @@ -124,20 +167,71 @@ class ufunc: def __repr__(self) -> str: return f"" - def __call__(self, *args: ArrayLike, - out: None = None, where: None = None, - **kwargs: Any) -> Any: + def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any: + check_arraylike(self.__name__, *args) if out is not None: raise NotImplementedError(f"out argument of {self}") if where is not None: raise NotImplementedError(f"where argument of {self}") - return self._call(*args, **kwargs) + call = self.__static_props['call'] or self._call_vectorized + return call(*args) + + @partial(jax.jit, static_argnames=['self']) + def _call_vectorized(self, *args): + return vectorize(self._func)(*args) - @implements(np.ufunc.reduce, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) - def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + def reduce(self, a: ArrayLike, axis: int = 0, + dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Reduction operation derived from a binary function. + + JAX implementation of :meth:`numpy.ufunc.reduce`. + + Args: + a: Input array. + axis: integer specifying the axis over which to reduce. default=0 + dtype: optionally specify the type of the output array. + out: Unused by JAX + keepdims: If True, reduced axes are left in the result with size 1. + If False (default) then reduced axes are squeezed out. + initial: int or array, Default=None. Initial value for the reduction. + where: boolean mask, default=None. The elements to be used in the sum. Array + should be broadcast compatible to the input. + + Returns: + array containing the result of the reduction operation. + + Examples: + Consider the following array: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + + :meth:`jax.numpy.add.reduce` is equivalent to :func:`jax.numpy.sum` + along ``axis=0``: + + >>> jnp.add.reduce(x) + Array([5, 7, 9], dtype=int32) + >>> x.sum(0) + Array([5, 7, 9], dtype=int32) + + Similarly, :meth:`jax.numpy.logical_and.reduce` is equivalent to + :func:`jax.numpy.all`: + + >>> jnp.logical_and.reduce(x > 2) + Array([False, False, True], dtype=bool) + >>> jnp.all(x > 2, axis=0) + Array([False, False, True], dtype=bool) + + Some reductions do not correspond to any built-in aggregation function; + for example here is the reduction of :func:`jax.numpy.bitwise_or` along + the first axis of ``x``: + + >>> jnp.bitwise_or.reduce(x, axis=1) + Array([3, 7], dtype=int32) + """ check_arraylike(f"{self.__name__}.reduce", a) if self.nin != 2: raise ValueError("reduce only supported for binary ufuncs") @@ -154,14 +248,10 @@ class ufunc: "so to use a where mask one has to specify 'initial'.") if lax_internal._dtype(where) != bool: raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - reducer = self._reduce_via_scan - else: - reducer = _primitive_reducers.get(primitive, self._reduce_via_scan) - return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + reduce = self.__static_props['reduce'] or self._reduce_via_scan + return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 @@ -202,9 +292,9 @@ class ufunc: def body_fun(i, val): if where is None: - return self._call(val, arr[i].astype(dtype)) + return self(val, arr[i].astype(dtype)) else: - return _where(where[i], self._call(val, arr[i].astype(dtype)), val) + return _where(where[i], self(val, arr[i].astype(dtype)), val) start_value: ArrayLike if initial is None: @@ -221,22 +311,63 @@ class ufunc: result = result.reshape(final_shape) return result - @implements(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: + """Accumulate operation derived from binary ufunc. + + JAX implementation of :func:`numpy.ufunc.accumulate`. + + Args: + a: N-dimensional array over which to accumulate. + axis: integer axis over which accumulation will be performed (default = 0) + dtype: optionally specify the type of the output array. + out: Unused by JAX + + Returns: + An array containing the accumulated result. + + Examples: + Consider the following array: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + + :meth:`jax.numpy.add.accumulate` is equivalent to + :func:`jax.numpy.cumsum` along the specified axis: + >>> jnp.add.accumulate(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + >>> jnp.cumsum(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + + Similarly, :meth:`jax.numpy.multiply.accumulate` is equivalent to + :func:`jax.numpy.cumprod` along the specified axis: + + >>> jnp.multiply.accumulate(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + >>> jnp.cumprod(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + + For other binary ufuncs, the accumulation is an operation not available + via standard APIs. For example, :meth:`jax.numpy.bitwise_or.accumulate` + is essentially a bitwise cumulative ``any``: + + >>> jnp.bitwise_or.accumulate(x, axis=1) + Array([[1, 3, 3], + [4, 5, 7]], dtype=int32) + """ if self.nin != 2: raise ValueError("accumulate only supported for binary ufuncs") if self.nout != 1: raise ValueError("accumulate only supported for functions returning a single value") if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.accumulate()") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - accumulator = self._accumulate_via_scan - else: - accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan) - return accumulator(a, axis=axis, dtype=dtype) + accumulate = self.__static_props['accumulate'] or self._accumulate_via_scan + return accumulate(a, axis=axis, dtype=dtype) def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None) -> Array: @@ -254,21 +385,54 @@ class ufunc: arr = _moveaxis(arr, axis, 0) def scan_fun(carry, _): i, x = carry - y = _where(i == 0, arr[0].astype(dtype), self._call(x.astype(dtype), arr[i].astype(dtype))) + y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype))) return (i + 1, y), y _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @implements(np.ufunc.at, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: + """Update elements of an array via the specified unary or binary ufunc. + + JAX implementation of :func:`numpy.ufunc.at`. + + Note: + :meth:`numpy.ufunc.at` mutates arrays in-place. JAX arrays are immutable, + so :meth:`jax.numpy.ufunc.at` cannot replicate these semantics. Instead, JAX + will return the updated value, but requires explicitly passing ``inplace=False`` + as a reminder of this difference. + + Args: + a: N-dimensional array to update + indices: index, slice, or tuple of indices and slices. + b: array of values for binary ufunc updates. + inplace: must be set to False to indicate that an updated copy will be returned. + + Returns: + an updated copy of the input array. + + Examples: + + Add numbers to specified indices: + + >>> x = jnp.ones(10, dtype=int) + >>> indices = jnp.array([2, 5, 7]) + >>> values = jnp.array([10, 20, 30]) + >>> jnp.add.at(x, indices, values, inplace=False) + Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32) + + This is roughly equivalent to JAX's :meth:`jax.numpy.ndarray.at` method + called this way: + + >>> x.at[indices].add(values) + Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32) + """ if inplace: raise NotImplementedError(_AT_INPLACE_WARNING) - if b is None: - return self._at_via_scan(a, indices) - else: - return self._at_via_scan(a, indices, b) + + at = self.__static_props['at'] or self._at_via_scan + return at(a, indices) if b is None else at(a, indices, b) def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: assert len(args) in {0, 1} @@ -276,14 +440,14 @@ class ufunc: dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype a = lax_internal.asarray(a).astype(dtype) args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) - indices = _eliminate_deprecated_list_indexing(indices) + indices = jnp._eliminate_deprecated_list_indexing(indices) if not indices: return a shapes = [np.shape(i) for i in indices if not isinstance(i, slice)] shape = shapes and jax.lax.broadcast_shapes(*shapes) if not shape: - return a.at[indices].set(self._call(a.at[indices].get(), *args)) + return a.at[indices].set(self(a.at[indices].get(), *args)) if args: arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):])) @@ -293,28 +457,65 @@ class ufunc: def scan_fun(carry, x): i, a = carry idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices) - a = a.at[idx].set(self._call(a.at[idx].get(), *(arg[i] for arg in args))) + a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args))) return (i + 1, a), x carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) return carry[1] - @implements(np.ufunc.reduceat, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: + """Reduce an array between specified indices via a binary ufunc. + + JAX implementation of :meth:`numpy.ufunc.reduceat` + + Args: + a: N-dimensional array to reduce + indices: a 1-dimensional array of increasing integer values which encodes + segments of the array to be reduced. + axis: integer specifying the axis along which to reduce: default=0. + dtype: optionally specify the dtype of the output array. + out: unused by JAX + Returns: + An array containing the reduced values. + + Examples: + The ``reduce`` method lets you efficiently compute reduction operations + over array segments. For example: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) + >>> indices = jnp.array([0, 2, 5]) + >>> jnp.add.reduce(x, indices) + Array([ 3, 12, 21], dtype=int32) + + This is more-or-less equivalent to the following: + + >>> jnp.array([x[0:2].sum(), x[2:5].sum(), x[5:].sum()]) + Array([ 3, 12, 21], dtype=int32) + + For some binary ufuncs, JAX provides similar APIs within :mod:`jax.ops`. + For example, :meth:`jax.add.reduceat` is similar to :func:`jax.ops.segment_sum`, + although in this case the segments are defined via an array of segment ids: + + >>> segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 2]) + >>> jax.ops.segment_sum(x, segments) + Array([ 3, 12, 21], dtype=int32) + """ if self.nin != 2: raise ValueError("reduceat only supported for binary ufuncs") if self.nout != 1: raise ValueError("reduceat only supported for functions returning a single value") if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.reduceat()") - return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype) + + reduceat = self.__static_props['reduceat'] or self._reduceat_via_scan + return reduceat(a, indices, axis=axis, dtype=dtype) def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) a = lax_internal.asarray(a) - idx_tuple = _eliminate_deprecated_list_indexing(indices) + idx_tuple = jnp._eliminate_deprecated_list_indexing(indices) assert len(idx_tuple) == 1 indices = idx_tuple[0] if a.ndim == 0: @@ -326,27 +527,62 @@ class ufunc: if axis is None or isinstance(axis, (tuple, list)): raise ValueError("reduceat requires a single integer axis.") axis = canonicalize_axis(axis, a.ndim) - out = take(a, indices, axis=axis) - ind = jax.lax.expand_dims(append(indices, a.shape[axis]), + out = jnp.take(a, indices, axis=axis) + ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]), list(np.delete(np.arange(out.ndim), axis))) ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), + self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), out) return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) - @implements(np.ufunc.outer, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0]) - def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array: + def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: + """Apply the function to all pairs of values in ``A`` and ``B``. + + JAX implementation of :meth:`numpy.ufunc.outer`. + + Args: + A: N-dimensional array + B: N-dimensional array + + Returns: + An array of shape `tuple(*A.shape, *B.shape)` + + Examples: + A times-table for integers 1...10 created via + :meth:`jax.numpy.multiply.outer`: + + >>> x = jnp.arange(1, 11) + >>> print(jnp.multiply.outer(x, x)) + [[ 1 2 3 4 5 6 7 8 9 10] + [ 2 4 6 8 10 12 14 16 18 20] + [ 3 6 9 12 15 18 21 24 27 30] + [ 4 8 12 16 20 24 28 32 36 40] + [ 5 10 15 20 25 30 35 40 45 50] + [ 6 12 18 24 30 36 42 48 54 60] + [ 7 14 21 28 35 42 49 56 63 70] + [ 8 16 24 32 40 48 56 64 72 80] + [ 9 18 27 36 45 54 63 72 81 90] + [ 10 20 30 40 50 60 70 80 90 100]] + + For input arrays with ``N`` and ``M`` dimensions respectively, the output + will have dimesion ``N + M``: + + >>> x = jnp.ones((1, 3, 5)) + >>> y = jnp.ones((2, 4)) + >>> jnp.add.outer(x, y).shape + (1, 3, 5, 2, 4) + """ if self.nin != 2: raise ValueError("outer only supported for binary ufuncs") if self.nout != 1: raise ValueError("outer only supported for functions returning a single value") check_arraylike(f"{self.__name__}.outer", A, B) _ravel = lambda A: jax.lax.reshape(A, (np.size(A),)) - result = jax.vmap(jax.vmap(partial(self._call, **kwargs), (None, 0)), (0, None))(_ravel(A), _ravel(B)) + result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) return result.reshape(*np.shape(A), *np.shape(B)) @@ -363,4 +599,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, Returns: wrapped : jax.numpy.ufunc wrapper of func. """ - return ufunc(func, nin, nout, identity=identity, update_doc=True) + return ufunc(func, nin, nout, identity=identity) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dfeff38df..aa8ac4e95 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -30,11 +30,13 @@ from jax._src.api import jit from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax from jax._src.lax import other as lax_other -from jax._src.typing import Array, ArrayLike +from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, implements, check_no_float0s) +from jax._src.numpy.ufunc_api import ufunc +from jax._src.numpy import reductions _lax_const = lax._const @@ -298,31 +300,81 @@ def sqrt(x: ArrayLike, /) -> Array: def cbrt(x: ArrayLike, /) -> Array: return lax.cbrt(*promote_args_inexact('cbrt', x)) -@implements(np.add, module='numpy') @partial(jit, inline=True) -def add(x: ArrayLike, y: ArrayLike, /) -> Array: +def _add(x: ArrayLike, y: ArrayLike, /) -> Array: + """Add two arrays element-wise. + + JAX implementation of :obj:`numpy.add`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: arrays to add. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise addition. + """ x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) -@implements(np.multiply, module='numpy') @partial(jit, inline=True) -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: +def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: + """Multiply two arrays element-wise. + + JAX implementation of :obj:`numpy.multiply`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: arrays to multiply. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise multiplication. + """ x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) -@implements(np.bitwise_and, module='numpy') @partial(jit, inline=True) -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise AND operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise AND. + """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) -@implements(np.bitwise_or, module='numpy') @partial(jit, inline=True) -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise OR operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise OR. + """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) -@implements(np.bitwise_xor, module='numpy') @partial(jit, inline=True) -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise XOR operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise XOR. + """ return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) @implements(np.left_shift, module='numpy') @@ -376,19 +428,49 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.nextafter(*promote_args_inexact("nextafter", x, y)) # Logical ops -@implements(np.logical_and, module='numpy') @partial(jit, inline=True) -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical AND operation elementwise. + + JAX implementation of :obj:`numpy.logical_and`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical AND. + """ return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) -@implements(np.logical_or, module='numpy') @partial(jit, inline=True) -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical OR operation elementwise. + + JAX implementation of :obj:`numpy.logical_or`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical OR. + """ return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) -@implements(np.logical_xor, module='numpy') @partial(jit, inline=True) -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical XOR operation elementwise. + + JAX implementation of :obj:`numpy.logical_xor`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical XOR. + """ return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) @implements(np.logical_not, module='numpy') @@ -1281,3 +1363,38 @@ def _sinc_maclaurin(k, x): def _sinc_maclaurin_jvp(k, primals, tangents): (x,), (t,) = primals, tangents return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t + + +def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_and.reduce()") + result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + + +def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_or.reduce()") + result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + + +# Generate ufunc interfaces for several common binary functions. +# We start with binary ufuncs that have well-defined identities.' +# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience? +# TODO(jakevdp): optimize some implementations. +# - define add.at/multiply.at in terms of scatter_add/scatter_mul +# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod +# - define all monoidal reductions in terms of lax.reduce +add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum) +multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod) +bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and) +bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or) +bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor) +logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce) +logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce) +logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index ffa3a103e..64d461fe9 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -965,8 +965,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) else: graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) + if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def: + self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) def test_bfloat16_constant(self): # Re: https://github.com/google/jax/issues/3942 diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 583f6886e..5e2c1dce4 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -3,7 +3,7 @@ from __future__ import annotations import builtins from collections.abc import Callable, Sequence -from typing import Any, Literal, NamedTuple, TypeVar, Union, overload +from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes @@ -28,6 +28,34 @@ _Device = Device ComplexWarning: type +class BinaryUfunc(Protocol): + @property + def nin(self) -> int: ... + @property + def nout(self) -> int: ... + @property + def nargs(self) -> int: ... + @property + def identity(self) -> builtins.bool | int | float: ... + def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ... + def reduce(self, arr: ArrayLike, /, *, + axis: int | None = 0, + dtype: DTypeLike | None = None, + keepdims: builtins.bool = False, + initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def accumulate(self, a: ArrayLike, /, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, + inplace: builtins.bool = True) -> Array: ... + def reduceat(self, a: ArrayLike, indices: Any, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ... + __array_api_version__: str def __array_namespace_info__() -> ArrayNamespaceInfo: ... @@ -36,7 +64,7 @@ def abs(x: ArrayLike, /) -> Array: ... def absolute(x: ArrayLike, /) -> Array: ... def acos(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ... -def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... +add: BinaryUfunc def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... @@ -162,14 +190,14 @@ def bartlett(M: int) -> Array: ... bfloat16: Any def bincount(x: ArrayLike, weights: ArrayLike | None = ..., minlength: int = ..., *, length: int | None = ...) -> Array: ... -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_and: BinaryUfunc def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ... def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_not(x: ArrayLike, /) -> Array: ... -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_or: BinaryUfunc def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_xor: BinaryUfunc def blackman(M: int) -> Array: ... def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any @@ -251,7 +279,7 @@ def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., - include_initial: bool = ...) -> Array: ... + include_initial: builtins.bool = ...) -> Array: ... def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg @@ -557,10 +585,10 @@ def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_or: BinaryUfunc +logical_xor: BinaryUfunc def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... @@ -588,7 +616,7 @@ def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: ... -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... +multiply: BinaryUfunc nan: float def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., posinf: ArrayLike | None = ..., diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 16eb9321c..c65df8aa8 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -14,6 +14,7 @@ """Tests for jax.numpy.ufunc and its methods.""" +import itertools from functools import partial from absl.testing import absltest @@ -22,7 +23,6 @@ import numpy as np import jax import jax.numpy as jnp from jax._src import test_util as jtu -from jax._src.numpy.ufunc_api import get_if_single_primitive jax.config.parse_flags_with_absl() @@ -54,18 +54,21 @@ SCALAR_FUNCS = [ {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None}, ] -FASTPATH_FUNCS = [ - {'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0, - 'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p}, - {'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1, - 'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p}, +def _jnp_ufunc_props(name): + jnp_func = getattr(jnp, name) + assert isinstance(jnp_func, jnp.ufunc) + np_func = getattr(np, name) + dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types] + return [dict(name=name, dtype=dtype) for dtype in dtypes] + + +JAX_NUMPY_UFUNCS = [ + name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc) ] -NON_FASTPATH_FUNCS = [ - {'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0}, - {'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1}, - {'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1}, -] +JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS +)) broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] nonscalar_shapes = [(3,), (4,), (4, 3)] @@ -80,23 +83,40 @@ def cast_outputs(fun): class LaxNumpyUfuncTests(jtu.JaxTestCase): @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties(self, func, nin, nout, identity): + def test_frompyfunc_properties(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun.identity, identity) self.assertEqual(jnp_fun.nin, nin) self.assertEqual(jnp_fun.nout, nout) self.assertEqual(jnp_fun.nargs, nin) + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties(self, name): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + self.assertEqual(jnp_fun.identity, np_fun.identity) + self.assertEqual(jnp_fun.nin, np_fun.nin) + self.assertEqual(jnp_fun.nout, np_fun.nout) + self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 because NumPy accepts `out` + @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties_readonly(self, func, nin, nout, identity): + def test_frompyfunc_properties_readonly(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) - for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']: + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: + getattr(jnp_fun, attr) # no error on attribute access. + with self.assertRaises(AttributeError): + setattr(jnp_fun, attr, None) # error when trying to mutate. + + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties_readonly(self, name): + jnp_fun = getattr(jnp, name) + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: getattr(jnp_fun, attr) # no error on attribute access. with self.assertRaises(AttributeError): setattr(jnp_fun, attr, None) # error when trying to mutate. @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_hash(self, func, nin, nout, identity): + def test_frompyfunc_hash(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun, jnp_fun_2) @@ -113,7 +133,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): dtype=jtu.dtypes.floating, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity)) @@ -123,13 +143,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( SCALAR_FUNCS, lhs_shape=broadcast_compatible_shapes, rhs_shape=broadcast_compatible_shapes, dtype=jtu.dtypes.floating, ) - def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer @@ -141,6 +176,23 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker) + self._CompileAndCheck(jnp_fun.outer, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} @@ -148,7 +200,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) @@ -160,6 +212,26 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_ufunc_reduce(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis) + np_fun_reduce = partial(np_fun.reduce, axis=axis) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} @@ -167,7 +239,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") @@ -194,42 +266,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - FASTPATH_FUNCS, + JAX_NUMPY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in [None, *range(-len(shape), len(shape))]], ) - def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del accumulator # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") + def test_ufunc_reduce_where(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if jnp_fun.identity is None: + self.skipTest("reduce with where requires identity") + + jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where) + np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where) + rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer) - - @jtu.sample_product( - NON_FASTPATH_FUNCS, - [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, - ) - def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype): - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") - rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - - _ = func(0, 0) # function should not error. - - reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertIsNone(get_if_single_primitive(reduce_fun, *args)) - - accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertIsNone(get_if_single_primitive(accum_fun, *args)) + rng_where = jtu.rand_bool(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)] + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -238,7 +296,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): for axis in range(-len(shape), len(shape))], dtype=jtu.dtypes.floating, ) - def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) @@ -251,20 +309,28 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - FASTPATH_FUNCS, + JAX_NUMPY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in range(-len(shape), len(shape))] ) - def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del reducer # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") + def test_ufunc_accumulate(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator) + args_maker = lambda: [rng(shape, dtype)] + + jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis) + def np_fun_accumulate(x): + # numpy accumulate has different dtype casting behavior. + result = np_fun.accumulate(x, axis=axis) + return result if x.dtype == bool else result.astype(x.dtype) + + self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker) + self._CompileAndCheck(jnp_fun_accumulate, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -272,7 +338,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): idx_shape=[(), (2,)], dtype=jtu.dtypes.floating, ) - def test_at(self, func, nin, nout, identity, shape, idx_shape, dtype): + def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False) @@ -288,7 +354,31 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def test_at_broadcasting(self): + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + shape=nonscalar_shapes, + idx_shape=[(), (2,)], + ) + def test_ufunc_at(self, name, shape, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)] + + jnp_fun_at = partial(jnp_fun.at, inplace=False) + def np_fun_at(x, idx, y): + x_copy = x.copy() + np_fun.at(x_copy, idx, y) + return x_copy + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + self._CompileAndCheck(jnp_fun_at, args_maker) + + def test_frompyfunc_at_broadcasting(self): # Regression test for https://github.com/google/jax/issues/18004 args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]), np.arange(9.0).reshape(3, 3)] @@ -309,7 +399,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): idx_shape=[(0,), (3,), (5,)], dtype=jtu.dtypes.floating, ) - def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): + def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis) @@ -322,6 +412,33 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [*range(-len(shape), len(shape))]], + idx_shape=[(0,), (3,), (5,)], + ) + def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if name in ['add', 'multiply'] and dtype == bool: + # TODO(jakevdp): figure out how to fix thest cases. + self.skipTest(f"known failure for {name}.reduceat with {dtype=}") + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] + + def np_fun_reduceat(x, i): + # Numpy has different casting behavior. + return np_fun.reduceat(x, i).astype(x.dtype) + + self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker) + self._CompileAndCheck(jnp_fun.reduceat, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())