18 Commits

Author SHA1 Message Date
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Felix Chern
8ac7422e26 [JAX] Disables large k test cases in ann_test.
Will investigate probability properties for the corner cases in the future.

PiperOrigin-RevId: 487302143
2022-11-09 11:32:47 -08:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
0bfb3efcd7 [JAX] Fix batch logic for approx_min/max_k
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.

PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
jax authors
d9f82f7b9b [JAX] Move experimental.ann.approx_*_k into lax.
Updated docs, tests and the example code snippets.

PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
jax authors
8372b98c48 [JAX] Move ann.ann_recall back to tests.
The function is simple enough for users to implement their own on the host.

PiperOrigin-RevId: 430696789
2022-02-24 07:23:17 -08:00
jax authors
a0abe8e4ac [JAX] Move the ann recall computation to ann.py.
This function is very useful for our users to evaluate the ann results
against the standard ann datasets that provides the ground truth.

PiperOrigin-RevId: 425997236
2022-02-02 15:50:13 -08:00
jax authors
c12ca7f64c [XLA:TPU] Add 'aggregate_to_topk' option to ann in jax
Also adds a pmap test for demonstrating multi-tpu ann.

PiperOrigin-RevId: 425451716
2022-01-31 13:46:07 -08:00
Peter Hawkins
954cb9983b [JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays.
See https://github.com/google/jax/pull/8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.

PiperOrigin-RevId: 406247725
2021-10-28 16:49:37 -07:00
Sharad Vikram
b6fa33fd2d [JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays.
See https://github.com/google/jax/pull/8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays.

PiperOrigin-RevId: 405995198
2021-10-27 15:26:52 -07:00
jax authors
789ce1e835 [XLA:TPU] Fix shape bug for k=1 in ApproxTopK
PiperOrigin-RevId: 405745255
2021-10-26 15:07:40 -07:00
jax authors
06b595321f [XLA:TPU] Support jvp/vjp in approx_top_k
Copies the jvp implementation lax.sort uses.
Left some comments for future optimizations

PiperOrigin-RevId: 404608289
2021-10-20 12:08:04 -07:00
jax authors
b09501f80e [XLA:TPU] Fix approx_top_k output slice.
ApproxTopK should slice the output to k on reduction dim.

PiperOrigin-RevId: 404371519
2021-10-19 14:50:34 -07:00
jax authors
e96f363242 [XLA:TPU] Adding approximate nearest neighbor search on TPU feature to JAX.
The JAX primitive would call the XLA python interface for ApproxTopK on TPU,
and fallbacked to sort-and-slice XLA implementation on other platforms.

Auto differntiation have two possible implementations and will be
submitted in seprated CLs.

PiperOrigin-RevId: 404263763
2021-10-19 08:09:43 -07:00