From 1f9a2dddb899c883f0242a7a7af6c2b13d4f52ec Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 5 Apr 2024 09:42:49 -0700 Subject: [PATCH] ufunc: fix implements wrapper for at --- jax/_src/numpy/ufunc_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 2d3eb1edf..29f5278bc 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -258,7 +258,7 @@ class ufunc: _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @implements(np.ufunc.accumulate, module="numpy.ufunc") + @implements(np.ufunc.at, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: