jnp.frompyfunc: fix .at() edge case

This commit is contained in:
Jake VanderPlas 2023-10-09 11:24:01 -07:00
parent 84b58ec7f3
commit 41a7d66686
2 changed files with 18 additions and 1 deletions

View File

@ -20,6 +20,7 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
>>> arr = jnp.add.at(arr, ind, val, inplace=False)
"""
from functools import partial
import math
import operator
from typing import Any, Callable, Optional
@ -264,6 +265,7 @@ class ufunc:
return self._at_via_scan(a, indices, b)
def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
assert len(args) in {0, 1}
check_arraylike(f"{self.__name__}.at", a, *args)
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
a = lax_internal.asarray(a).astype(dtype)
@ -277,7 +279,9 @@ class ufunc:
if not shape:
return a.at[indices].set(self._call(a.at[indices].get(), *args))
args = tuple(_broadcast_to(arg, shape).ravel() for arg in args)
if args:
arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):]))
args = (arg.reshape(math.prod(shape), *args[0].shape[len(shape):]),)
indices = [idx if isinstance(idx, slice) else _broadcast_to(idx, shape).ravel() for idx in indices]
def scan_fun(carry, x):

View File

@ -289,6 +289,19 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def test_at_broadcasting(self):
# Regression test for https://github.com/google/jax/issues/18004
args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
np.arange(9.0).reshape(3, 3)]
def np_fun(x, idx, y):
x_copy = np.copy(x)
np.add.at(x_copy, idx, y)
return x_copy
jnp_fun = partial(jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0).at, inplace=False)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}