diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 0ef7aeba8..42c14aef1 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -20,6 +20,7 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. >>> arr = jnp.add.at(arr, ind, val, inplace=False) """ from functools import partial +import math import operator from typing import Any, Callable, Optional @@ -264,6 +265,7 @@ class ufunc: return self._at_via_scan(a, indices, b) def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: + assert len(args) in {0, 1} check_arraylike(f"{self.__name__}.at", a, *args) dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype a = lax_internal.asarray(a).astype(dtype) @@ -277,7 +279,9 @@ class ufunc: if not shape: return a.at[indices].set(self._call(a.at[indices].get(), *args)) - args = tuple(_broadcast_to(arg, shape).ravel() for arg in args) + if args: + arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):])) + args = (arg.reshape(math.prod(shape), *args[0].shape[len(shape):]),) indices = [idx if isinstance(idx, slice) else _broadcast_to(idx, shape).ravel() for idx in indices] def scan_fun(carry, x): diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index cace4d8d4..8f90a997e 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -289,6 +289,19 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def test_at_broadcasting(self): + # Regression test for https://github.com/google/jax/issues/18004 + args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]), + np.arange(9.0).reshape(3, 3)] + def np_fun(x, idx, y): + x_copy = np.copy(x) + np.add.at(x_copy, idx, y) + return x_copy + jnp_fun = partial(jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0).at, inplace=False) + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis}