Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.
The previous behavior can be approximated using the "clip" or "wrap" fill modes.
PiperOrigin-RevId: 445130143
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.
This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.
After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.
Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.
Issues: https://github.com/google/jax/issues/278https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
This change align the behavior of `ravel_multi_index`, `split` and `indices` to their `numpy` counterparts.
Also ensure size argument of `nonzero` should be integer.
The changes with `*space` are only simplification
Previously, out-of-bounds indices were clipped into range, but that behavior is error prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to have the same semantics.
PiperOrigin-RevId: 443390170
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
--
391dea76bc8fe264cf26ec93d42147f87847894d by Peter Hawkins <phawkins@google.com>:
Update version numbers after jax/jaxlib 0.3.7 release.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/10324 from hawkinsp:jaxlib 391dea76bc8fe264cf26ec93d42147f87847894d
PiperOrigin-RevId: 442311051
JAX kept an older name around (.block_host_until_ready()) in parallel with the new name (.block_until_ready()) to avoid breaking users. Deprecate it so we only have one name.
PiperOrigin-RevId: 433228545