Merge pull request #23739 from jakevdp:add-at

PiperOrigin-RevId: 677924623
This commit is contained in:
jax authors 2024-09-23 13:47:27 -07:00
commit 46867dc495

View File

@ -21,6 +21,7 @@ from __future__ import annotations
from collections.abc import Callable
from functools import partial
import operator
from typing import Any
import numpy as np
@ -2584,6 +2585,20 @@ def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = No
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
return result if dtype is None else result.astype(dtype)
def _add_at(a: Array, indices: Any, b: ArrayLike):
if a.dtype == bool:
a = a.astype('int32')
b = lax.convert_element_type(b, bool).astype('int32')
return a.at[indices].add(b).astype(bool)
return a.at[indices].add(b)
def _multiply_at(a: Array, indices: Any, b: ArrayLike):
if a.dtype == bool:
a = a.astype('int32')
b = lax.convert_element_type(b, bool).astype('int32')
return a.at[indices].mul(b).astype(bool)
else:
return a.at[indices].mul(b)
# Generate ufunc interfaces for several common binary functions.
# We start with binary ufuncs that have well-defined identities.'
@ -2592,8 +2607,8 @@ def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = No
# - define add.at/multiply.at in terms of scatter_add/scatter_mul
# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod
# - define all monoidal reductions in terms of lax.reduce
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum)
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod)
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at)
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at)
bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and)
bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or)
bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor)