Merge pull request #25880 from jakevdp:fix-gather

PiperOrigin-RevId: 715804120
This commit is contained in:
jax authors 2025-01-15 08:10:44 -08:00
commit 2e5e4799fd
2 changed files with 14 additions and 2 deletions

View File

@ -11846,8 +11846,8 @@ def _is_valid_integer_index_for_slice(idx, size, mode):
except:
return False
if shape == () and np.issubdtype(dtype, np.integer):
# For dynamic integer indices, dynamic_slice semantics require index clipping:
return mode in [None, 'promise_inbounds', 'clip']
# For dynamic integer indices, semantics require promise_inbounds.
return mode in [None, 'promise_inbounds']
return False
def _is_contiguous_slice(idx):

View File

@ -1675,5 +1675,17 @@ class IndexedUpdateTest(jtu.JaxTestCase):
c = a.at[0].set(val)
self.assertEqual(int(c[0]), val)
def testGradOfVmapOfScatter(self):
# Regression test for https://github.com/jax-ml/jax/issues/25878
def f(x, i):
return x.at[i].get(mode='clip')
x = jnp.array([1.0])
i = jnp.array([1]) # out-of-bound index
expected = jnp.array([[1.0]])
self.assertArraysEqual(jax.jacrev(f)(x, i), expected)
self.assertArraysEqual(jax.jacrev(jax.vmap(f, (None, 0)))(x, i), expected)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())