mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Fix uint32 scatter assignment
This commit is contained in:
parent
8abe03daed
commit
e061b91ffc
@ -64,9 +64,13 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
|
||||
Returns:
|
||||
An ndarray representing an updated `x` after performing the scatter-update.
|
||||
"""
|
||||
|
||||
x = jnp.asarray(x)
|
||||
y = jnp.asarray(y)
|
||||
if (isinstance(y, int) and np.issubdtype(x.dtype, np.integer) and
|
||||
np.iinfo(x.dtype).min <= y <= np.iinfo(x.dtype).max):
|
||||
y = jnp.asarray(y, dtype=x.dtype)
|
||||
else:
|
||||
y = jnp.asarray(y)
|
||||
|
||||
# XLA gathers and scatters are very similar in structure; the scatter logic
|
||||
# is more or less a transpose of the gather equivalent.
|
||||
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
|
||||
|
@ -1467,5 +1467,16 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
y = jnp.zeros(8)
|
||||
self.assertArraysEqual(fn(y), jax.jit(fn)(y))
|
||||
|
||||
def testScatterValuesCastToTargetDType(self):
|
||||
# https://github.com/google/jax/issues/15505
|
||||
a = jnp.zeros(1, dtype=jnp.uint32)
|
||||
val = 2**32 - 1 # too large for int32
|
||||
|
||||
b = a.at[0].set(jnp.uint32(val))
|
||||
self.assertEqual(int(b[0]), val)
|
||||
|
||||
c = a.at[0].set(val)
|
||||
self.assertEqual(int(c[0]), val)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user