mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
ufunc: fix implements wrapper for at
This commit is contained in:
parent
a5b8ce1208
commit
1f9a2dddb8
@ -258,7 +258,7 @@ class ufunc:
|
||||
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
|
||||
return _moveaxis(result, 0, axis)
|
||||
|
||||
@implements(np.ufunc.accumulate, module="numpy.ufunc")
|
||||
@implements(np.ufunc.at, module="numpy.ufunc")
|
||||
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
|
||||
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
|
||||
inplace: bool = True) -> Array:
|
||||
|
Loading…
x
Reference in New Issue
Block a user