diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a92bef85e..8f078ceee 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -11846,8 +11846,8 @@ def _is_valid_integer_index_for_slice(idx, size, mode): except: return False if shape == () and np.issubdtype(dtype, np.integer): - # For dynamic integer indices, dynamic_slice semantics require index clipping: - return mode in [None, 'promise_inbounds', 'clip'] + # For dynamic integer indices, semantics require promise_inbounds. + return mode in [None, 'promise_inbounds'] return False def _is_contiguous_slice(idx): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index ab625d10b..04225d6c5 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1675,5 +1675,17 @@ class IndexedUpdateTest(jtu.JaxTestCase): c = a.at[0].set(val) self.assertEqual(int(c[0]), val) + def testGradOfVmapOfScatter(self): + # Regression test for https://github.com/jax-ml/jax/issues/25878 + def f(x, i): + return x.at[i].get(mode='clip') + + x = jnp.array([1.0]) + i = jnp.array([1]) # out-of-bound index + expected = jnp.array([[1.0]]) + + self.assertArraysEqual(jax.jacrev(f)(x, i), expected) + self.assertArraysEqual(jax.jacrev(jax.vmap(f, (None, 0)))(x, i), expected) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())