Peter Hawkins 74346f464b [JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.
Previously, out-of-bounds indices were clipped into range, but that behavior is error-prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to the same semantics.

PiperOrigin-RevId: 443390170
2022-04-21 08:52:14 -07:00
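The snippet below is a minimal sketch of the behavior change, not code from the commit. It assumes a JAX release in which `jnp.take_along_axis` accepts a `mode` argument, with `"fill"` as the default and `"clip"` restoring the old clipping behavior. The array values are arbitrary, and index 5 is chosen to fall outside an axis of length 3.

```python
import jax.numpy as jnp

# A float array of length 3; index 5 is out of bounds along axis 0.
x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])

# New default: out-of-bounds positions yield an invalid value
# (NaN for floating-point dtypes) instead of silently reading x[-1].
filled = jnp.take_along_axis(x, idx, axis=0)

# Assumption: the old clipping behavior is still available via mode="clip",
# which maps index 5 to the last valid index (2).
clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")

print(filled)   # last entry is NaN
print(clipped)  # last entry is 30.0, the clipped read
```

The NaN makes the bad index visible downstream, whereas the clipped result looks like ordinary data.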