Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.
The previous behavior can be approximated using the "clip" or "wrap" fill modes.
PiperOrigin-RevId: 445130143
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.
This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.
After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.
Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.
Issues: https://github.com/google/jax/issues/278https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
--
4680b86ff7f468429a0820b4f8c7f64ffd1a1cad by Matthew Johnson <mattjj@google.com>:
[remove-units] prevent scan partial eval from introducing units
PiperOrigin-RevId: 444698613
a87b21148c doesn't notice `_scatter_add_lower_gpu` using `mlir.lower_fun` instead of `xla.lower_fun`.
I follow the change done in that commit for _scatter_lower.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
At the moment this change does nothing since standard_primitive already registers these same translation rules. The change is in preparation for removing the behavior of standard_primitive of registering an XLA translation rule.
PiperOrigin-RevId: 442222533
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):
```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.
PiperOrigin-RevId: 441474701
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.
PiperOrigin-RevId: 440445041