mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #25880 from jakevdp:fix-gather
PiperOrigin-RevId: 715804120
This commit is contained in:
commit
2e5e4799fd
@ -11846,8 +11846,8 @@ def _is_valid_integer_index_for_slice(idx, size, mode):
|
||||
except:
|
||||
return False
|
||||
if shape == () and np.issubdtype(dtype, np.integer):
|
||||
# For dynamic integer indices, dynamic_slice semantics require index clipping:
|
||||
return mode in [None, 'promise_inbounds', 'clip']
|
||||
# For dynamic integer indices, semantics require promise_inbounds.
|
||||
return mode in [None, 'promise_inbounds']
|
||||
return False
|
||||
|
||||
def _is_contiguous_slice(idx):
|
||||
|
@ -1675,5 +1675,17 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
c = a.at[0].set(val)
|
||||
self.assertEqual(int(c[0]), val)
|
||||
|
||||
def testGradOfVmapOfScatter(self):
|
||||
# Regression test for https://github.com/jax-ml/jax/issues/25878
|
||||
def f(x, i):
|
||||
return x.at[i].get(mode='clip')
|
||||
|
||||
x = jnp.array([1.0])
|
||||
i = jnp.array([1]) # out-of-bound index
|
||||
expected = jnp.array([[1.0]])
|
||||
|
||||
self.assertArraysEqual(jax.jacrev(f)(x, i), expected)
|
||||
self.assertArraysEqual(jax.jacrev(jax.vmap(f, (None, 0)))(x, i), expected)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user