Make jnp.subtract a ufunc

This commit is contained in:
Jake VanderPlas 2024-10-21 10:11:51 -07:00
parent ad53addb74
commit 6467d03925
5 changed files with 39 additions and 4 deletions

View File

@ -284,6 +284,8 @@ class _IndexUpdateRef:
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...
def add(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: str | None = None) -> Array: ...
def subtract(self, values: Any, *, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: str | None = None) -> Array: ...
def mul(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: str | None = None) -> Array: ...
def multiply(self, values: Any, indices_are_sorted: bool = False,

View File

@ -1787,7 +1787,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
slice1_tuple = tuple(slice1)
slice2_tuple = tuple(slice2)
op = ufuncs.not_equal if arr.dtype == np.bool_ else ufuncs.subtract
op = operator.not_equal if arr.dtype == np.bool_ else operator.sub
for _ in range(n):
arr = op(arr[slice1_tuple], arr[slice2_tuple])

View File

@ -1432,11 +1432,37 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.ne(*promote_args("not_equal", x, y))
@implements(np.subtract, module='numpy')
@partial(jit, inline=True)
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Subtract two arrays element-wise.
JAX implementation of :obj:`numpy.subtract`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``-`` operator for
JAX arrays.
Args:
x, y: arrays to subtract. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise subtraction.
Examples:
Calling ``subtract`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.subtract(x, 10)
Array([-10, -9, -8, -7], dtype=int32)
Calling ``subtract`` via the ``-`` operator:
>>> x - 10
Array([-10, -9, -8, -7], dtype=int32)
"""
return lax.sub(*promote_args("subtract", x, y))
@implements(np.arctan2, module='numpy')
@partial(jit, inline=True)
def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@ -3604,6 +3630,9 @@ def _add_at(a: Array, indices: Any, b: ArrayLike):
return a.at[indices].add(b).astype(bool)
return a.at[indices].add(b)
def _subtract_at(a: Array, indices: Any, b: ArrayLike):
return a.at[indices].subtract(b)
def _multiply_at(a: Array, indices: Any, b: ArrayLike):
if a.dtype == bool:
a = a.astype('int32')
@ -3628,3 +3657,4 @@ logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=Tr
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)
negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative)
subtract = ufunc(_subtract, name="subtract", nin=2, nout=1, call=_subtract, at=_subtract_at)

View File

@ -829,7 +829,7 @@ def stack(
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ...
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
subtract: BinaryUfunc
def sum(
a: ArrayLike,
axis: _Axis = ...,

View File

@ -250,6 +250,9 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if jnp_fun.identity is None and axis is None and len(shape) > 1:
self.skipTest("Multiple-axis reduction over non-reorderable ufunc.")
jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
np_fun_reduce = partial(np_fun.reduce, axis=axis)