mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23739 from jakevdp:add-at
PiperOrigin-RevId: 677924623
This commit is contained in:
commit
46867dc495
@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -2584,6 +2585,20 @@ def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = No
|
||||
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
|
||||
return result if dtype is None else result.astype(dtype)
|
||||
|
||||
def _add_at(a: Array, indices: Any, b: ArrayLike):
|
||||
if a.dtype == bool:
|
||||
a = a.astype('int32')
|
||||
b = lax.convert_element_type(b, bool).astype('int32')
|
||||
return a.at[indices].add(b).astype(bool)
|
||||
return a.at[indices].add(b)
|
||||
|
||||
def _multiply_at(a: Array, indices: Any, b: ArrayLike):
|
||||
if a.dtype == bool:
|
||||
a = a.astype('int32')
|
||||
b = lax.convert_element_type(b, bool).astype('int32')
|
||||
return a.at[indices].mul(b).astype(bool)
|
||||
else:
|
||||
return a.at[indices].mul(b)
|
||||
|
||||
# Generate ufunc interfaces for several common binary functions.
|
||||
# We start with binary ufuncs that have well-defined identities.'
|
||||
@ -2592,8 +2607,8 @@ def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = No
|
||||
# - define add.at/multiply.at in terms of scatter_add/scatter_mul
|
||||
# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod
|
||||
# - define all monoidal reductions in terms of lax.reduce
|
||||
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum)
|
||||
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod)
|
||||
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at)
|
||||
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at)
|
||||
bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and)
|
||||
bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or)
|
||||
bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor)
|
||||
|
Loading…
x
Reference in New Issue
Block a user