Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.
PiperOrigin-RevId: 440445041
This function is very useful for our users to evaluate the ann results
against the standard ann datasets that provides the ground truth.
PiperOrigin-RevId: 425997236
See https://github.com/google/jax/pull/8043 for context as to why we are making this change.
The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:
* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.
PiperOrigin-RevId: 406247725
See https://github.com/google/jax/pull/8043 for context as to why we are making this change.
The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays.
PiperOrigin-RevId: 405995198
The JAX primitive would call the XLA python interface for ApproxTopK on TPU,
and fallbacked to sort-and-slice XLA implementation on other platforms.
Auto differntiation have two possible implementations and will be
submitted in seprated CLs.
PiperOrigin-RevId: 404263763