Fix uint32 scatter assignment

This commit is contained in:
Jake VanderPlas 2023-04-10 14:24:26 -07:00
parent 8abe03daed
commit e061b91ffc
2 changed files with 17 additions and 2 deletions

View File

@ -64,9 +64,13 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
Returns:
An ndarray representing an updated `x` after performing the scatter-update.
"""
x = jnp.asarray(x)
y = jnp.asarray(y)
if (isinstance(y, int) and np.issubdtype(x.dtype, np.integer) and
np.iinfo(x.dtype).min <= y <= np.iinfo(x.dtype).max):
y = jnp.asarray(y, dtype=x.dtype)
else:
y = jnp.asarray(y)
# XLA gathers and scatters are very similar in structure; the scatter logic
# is more or less a transpose of the gather equivalent.
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)

View File

@ -1467,5 +1467,16 @@ class IndexedUpdateTest(jtu.JaxTestCase):
y = jnp.zeros(8)
self.assertArraysEqual(fn(y), jax.jit(fn)(y))
def testScatterValuesCastToTargetDType(self):
# https://github.com/google/jax/issues/15505
a = jnp.zeros(1, dtype=jnp.uint32)
val = 2**32 - 1 # too large for int32
b = a.at[0].set(jnp.uint32(val))
self.assertEqual(int(b[0]), val)
c = a.at[0].set(val)
self.assertEqual(int(c[0]), val)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())