ufunc: fix implements wrapper for at

This commit is contained in:
Jake VanderPlas 2024-04-05 09:42:49 -07:00
parent a5b8ce1208
commit 1f9a2dddb8

View File

@ -258,7 +258,7 @@ class ufunc:
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
return _moveaxis(result, 0, axis)
@implements(np.ufunc.accumulate, module="numpy.ufunc")
@implements(np.ufunc.at, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
inplace: bool = True) -> Array: