[JAX] Move experimental.ann.approx_*_k into lax.

Updated docs, tests and the example code snippets.

PiperOrigin-RevId: 431781401
This commit is contained in:
jax authors 2022-03-01 14:46:04 -08:00
parent 1a1bf122d9
commit d9f82f7b9b
7 changed files with 27 additions and 36 deletions

View File

@@ -1,13 +0,0 @@
jax.experimental.ann module
===========================
.. automodule:: jax.experimental.ann
API
---
.. autosummary::
:toctree: _autosummary
approx_max_k
approx_min_k

View File

@@ -6,13 +6,14 @@ jax.experimental package
``jax.experimental.optix`` has been moved into its own Python package
(https://github.com/deepmind/optax).
``jax.experimental.ann`` has been moved into ``jax.lax``.
Experimental Modules
--------------------
.. toctree::
:maxdepth: 1
jax.experimental.ann
jax.experimental.global_device_array
jax.experimental.host_callback
jax.experimental.loops

View File

@@ -26,6 +26,8 @@ Operators
abs
add
acos
approx_max_k
approx_min_k
argmax
argmin
asin

View File

@@ -21,7 +21,6 @@ Usage::
import functools
import jax
from jax.experimental import ann
# MIPS := maximal inner product search
# Inputs:
@@ -35,12 +34,7 @@ Usage::
dists = jax.lax.dot(qy, db.transpose())
# Computes max_k along the last dimension
# returns (f32[qy_size, k], i32[qy_size, k])
return ann.approx_max_k(dists, k=k, recall_target=recall_target)
# Obtains the top-10 dot products and its offsets in db.
dot_products, neighbors = mips(qy, db, k=10)
# Computes the recall against the true neighbors.
recall = ann.ann_recall(neighbors, true_neighbors)
return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
# Multi-core example
# Inputs:
@@ -58,7 +52,7 @@ Usage::
out_axes=(1, 1))
def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
dists = jax.lax.dot(qy, db.transpose())
dists, neighbors = ann.approx_max_k(
dists, neighbors = jax.lax.approx_max_k(
dists, k=k, recall_target=recall_target,
reduction_input_size_override=db_size)
return (dists, neighbors + db_offset)
@@ -79,7 +73,8 @@ from functools import partial
from typing import (Any, Tuple)
import numpy as np
from jax import lax, core
from jax import core
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src import ad_util, dtypes
@@ -125,12 +120,11 @@ def approx_max_k(operand: Array,
>>> import functools
>>> import jax
>>> import numpy as np
>>> from jax.experimental import ann
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def mips(qy, db, k=10, recall_target=0.95):
... dists = jax.lax.dot(qy, db.transpose())
... # returns (f32[qy_size, k], i32[qy_size, k])
... return ann.approx_max_k(dists, k=k, recall_target=recall_target)
... return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
@@ -185,11 +179,10 @@ def approx_min_k(operand: Array,
>>> import functools
>>> import jax
>>> import numpy as np
>>> from jax.experimental import ann
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
... dists = half_db_norms - jax.lax.dot(qy, db.transpose())
... return ann.approx_min_k(dists, k=k, recall_target=recall_target)
... return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))

View File

@@ -986,6 +986,7 @@ tf_not_yet_impl = [
# Not high priority?
"after_all",
"all_to_all",
"approx_top_k",
"create_token",
"custom_transpose_call",
"custom_vmap_call",

View File

@@ -369,5 +369,10 @@ from jax._src.lax.other import (
conv_general_dilated_local as conv_general_dilated_local,
conv_general_dilated_patches as conv_general_dilated_patches
)
from jax._src.lax.ann import (
approx_max_k as approx_max_k,
approx_min_k as approx_min_k,
approx_top_k_p as approx_top_k_p
)
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
from jax.lax import linalg as linalg

View File

@@ -21,7 +21,6 @@ import numpy as np
import jax
from jax import lax
from jax.experimental import ann
from jax._src import test_util as jtu
from jax._src.util import prod
@@ -80,7 +79,7 @@ class AnnTest(jtu.JaxTestCase):
db = rng(db_shape, dtype)
scores = lax.dot(qy, db)
_, gt_args = lax.top_k(scores, k)
_, ann_args = ann.approx_max_k(scores, k, recall_target=recall)
_, ann_args = lax.approx_max_k(scores, k, recall_target=recall)
self.assertEqual(k, len(ann_args[0]))
ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
self.assertGreater(ann_recall, recall)
@@ -103,7 +102,7 @@ class AnnTest(jtu.JaxTestCase):
db = rng(db_shape, dtype)
scores = lax.dot(qy, db)
_, gt_args = lax.top_k(-scores, k)
_, ann_args = ann.approx_min_k(scores, k, recall_target=recall)
_, ann_args = lax.approx_min_k(scores, k, recall_target=recall)
self.assertEqual(k, len(ann_args[0]))
ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
self.assertGreater(ann_recall, recall)
@@ -122,9 +121,9 @@ class AnnTest(jtu.JaxTestCase):
vals = np.arange(prod(shape), dtype=dtype)
vals = self.rng().permutation(vals).reshape(shape)
if is_max_k:
fn = lambda vs: ann.approx_max_k(vs, k=k)[0]
fn = lambda vs: lax.approx_max_k(vs, k=k)[0]
else:
fn = lambda vs: ann.approx_min_k(vs, k=k)[0]
fn = lambda vs: lax.approx_min_k(vs, k=k)[0]
jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)
@@ -153,10 +152,13 @@ class AnnTest(jtu.JaxTestCase):
db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device
def parallel_topk(qy, db, db_offset):
scores = lax.dot_general(qy, db, (([1],[1]),([],[])))
ann_vals, ann_args = ann.approx_min_k(scores, k=k, reduction_dimension=1,
recall_target=recall,
reduction_input_size_override=db_size,
aggregate_to_topk=False)
ann_vals, ann_args = lax.approx_min_k(
scores,
k=k,
reduction_dimension=1,
recall_target=recall,
reduction_input_size_override=db_size,
aggregate_to_topk=False)
return (ann_vals, ann_args + db_offset)
# shape = qy_size, num_devices, approx_dp
ann_vals, ann_args = jax.pmap(