mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jnp.frompyfunc: fix .at() edge case
This commit is contained in:
parent
84b58ec7f3
commit
41a7d66686
@ -20,6 +20,7 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
|
||||
>>> arr = jnp.add.at(arr, ind, val, inplace=False)
|
||||
"""
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@ -264,6 +265,7 @@ class ufunc:
|
||||
return self._at_via_scan(a, indices, b)
|
||||
|
||||
def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
|
||||
assert len(args) in {0, 1}
|
||||
check_arraylike(f"{self.__name__}.at", a, *args)
|
||||
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
|
||||
a = lax_internal.asarray(a).astype(dtype)
|
||||
@ -277,7 +279,9 @@ class ufunc:
|
||||
if not shape:
|
||||
return a.at[indices].set(self._call(a.at[indices].get(), *args))
|
||||
|
||||
args = tuple(_broadcast_to(arg, shape).ravel() for arg in args)
|
||||
if args:
|
||||
arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):]))
|
||||
args = (arg.reshape(math.prod(shape), *args[0].shape[len(shape):]),)
|
||||
indices = [idx if isinstance(idx, slice) else _broadcast_to(idx, shape).ravel() for idx in indices]
|
||||
|
||||
def scan_fun(carry, x):
|
||||
|
@ -289,6 +289,19 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def test_at_broadcasting(self):
|
||||
# Regression test for https://github.com/google/jax/issues/18004
|
||||
args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
|
||||
np.arange(9.0).reshape(3, 3)]
|
||||
def np_fun(x, idx, y):
|
||||
x_copy = np.copy(x)
|
||||
np.add.at(x_copy, idx, y)
|
||||
return x_copy
|
||||
jnp_fun = partial(jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0).at, inplace=False)
|
||||
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
|
Loading…
x
Reference in New Issue
Block a user