mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Make jnp.subtract a ufunc
This commit is contained in:
parent
ad53addb74
commit
6467d03925
@ -284,6 +284,8 @@ class _IndexUpdateRef:
|
||||
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...
|
||||
def add(self, values: Any, indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False, mode: str | None = None) -> Array: ...
|
||||
def subtract(self, values: Any, *, indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False, mode: str | None = None) -> Array: ...
|
||||
def mul(self, values: Any, indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False, mode: str | None = None) -> Array: ...
|
||||
def multiply(self, values: Any, indices_are_sorted: bool = False,
|
||||
|
@ -1787,7 +1787,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
|
||||
slice1_tuple = tuple(slice1)
|
||||
slice2_tuple = tuple(slice2)
|
||||
|
||||
op = ufuncs.not_equal if arr.dtype == np.bool_ else ufuncs.subtract
|
||||
op = operator.not_equal if arr.dtype == np.bool_ else operator.sub
|
||||
for _ in range(n):
|
||||
arr = op(arr[slice1_tuple], arr[slice2_tuple])
|
||||
|
||||
|
@ -1432,11 +1432,37 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""
|
||||
return lax.ne(*promote_args("not_equal", x, y))
|
||||
|
||||
@implements(np.subtract, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""Subtract two arrays element-wise.
|
||||
|
||||
JAX implementation of :obj:`numpy.subtract`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
This function provides the implementation of the ``-`` operator for
|
||||
JAX arrays.
|
||||
|
||||
Args:
|
||||
x, y: arrays to subtract. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise subtraction.
|
||||
|
||||
Examples:
|
||||
Calling ``subtract`` explicitly:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.subtract(x, 10)
|
||||
Array([-10, -9, -8, -7], dtype=int32)
|
||||
|
||||
Calling ``subtract`` via the ``-`` operator:
|
||||
|
||||
>>> x - 10
|
||||
Array([-10, -9, -8, -7], dtype=int32)
|
||||
"""
|
||||
return lax.sub(*promote_args("subtract", x, y))
|
||||
|
||||
|
||||
@implements(np.arctan2, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
@ -3604,6 +3630,9 @@ def _add_at(a: Array, indices: Any, b: ArrayLike):
|
||||
return a.at[indices].add(b).astype(bool)
|
||||
return a.at[indices].add(b)
|
||||
|
||||
def _subtract_at(a: Array, indices: Any, b: ArrayLike):
|
||||
return a.at[indices].subtract(b)
|
||||
|
||||
def _multiply_at(a: Array, indices: Any, b: ArrayLike):
|
||||
if a.dtype == bool:
|
||||
a = a.astype('int32')
|
||||
@ -3628,3 +3657,4 @@ logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=Tr
|
||||
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
|
||||
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)
|
||||
negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative)
|
||||
subtract = ufunc(_subtract, name="subtract", nin=2, nout=1, call=_subtract, at=_subtract_at)
|
||||
|
@ -829,7 +829,7 @@ def stack(
|
||||
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
||||
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
|
||||
where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ...
|
||||
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
subtract: BinaryUfunc
|
||||
def sum(
|
||||
a: ArrayLike,
|
||||
axis: _Axis = ...,
|
||||
|
@ -250,6 +250,9 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
|
||||
if jnp_fun.identity is None and axis is None and len(shape) > 1:
|
||||
self.skipTest("Multiple-axis reduction over non-reorderable ufunc.")
|
||||
|
||||
jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
|
||||
np_fun_reduce = partial(np_fun.reduce, axis=axis)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user