mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[JAX] Move experimental.ann.approx_*_k
into lax
.
Updated docs, tests and the example code snippets. PiperOrigin-RevId: 431781401
This commit is contained in:
parent
1a1bf122d9
commit
d9f82f7b9b
@ -1,13 +0,0 @@
|
||||
jax.experimental.ann module
|
||||
===========================
|
||||
|
||||
.. automodule:: jax.experimental.ann
|
||||
|
||||
API
|
||||
---
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
approx_max_k
|
||||
approx_min_k
|
@ -6,13 +6,14 @@ jax.experimental package
|
||||
``jax.experimental.optix`` has been moved into its own Python package
|
||||
(https://github.com/deepmind/optax).
|
||||
|
||||
``jax.experimental.ann`` has been moved into ``jax.lax``.
|
||||
|
||||
Experimental Modules
|
||||
--------------------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.ann
|
||||
jax.experimental.global_device_array
|
||||
jax.experimental.host_callback
|
||||
jax.experimental.loops
|
||||
|
@ -26,6 +26,8 @@ Operators
|
||||
abs
|
||||
add
|
||||
acos
|
||||
approx_max_k
|
||||
approx_min_k
|
||||
argmax
|
||||
argmin
|
||||
asin
|
||||
|
@ -21,7 +21,6 @@ Usage::
|
||||
|
||||
import functools
|
||||
import jax
|
||||
from jax.experimental import ann
|
||||
|
||||
# MIPS := maximal inner product search
|
||||
# Inputs:
|
||||
@ -35,12 +34,7 @@ Usage::
|
||||
dists = jax.lax.dot(qy, db.transpose())
|
||||
# Computes max_k along the last dimension
|
||||
# returns (f32[qy_size, k], i32[qy_size, k])
|
||||
return ann.approx_max_k(dists, k=k, recall_target=recall_target)
|
||||
|
||||
# Obtains the top-10 dot products and its offsets in db.
|
||||
dot_products, neighbors = mips(qy, db, k=10)
|
||||
# Computes the recall against the true neighbors.
|
||||
recall = ann.ann_recall(neighbors, true_neighbors)
|
||||
return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
|
||||
|
||||
# Multi-core example
|
||||
# Inputs:
|
||||
@ -58,7 +52,7 @@ Usage::
|
||||
out_axes=(1, 1))
|
||||
def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
|
||||
dists = jax.lax.dot(qy, db.transpose())
|
||||
dists, neighbors = ann.approx_max_k(
|
||||
dists, neighbors = jax.lax.approx_max_k(
|
||||
dists, k=k, recall_target=recall_target,
|
||||
reduction_input_size_override=db_size)
|
||||
return (dists, neighbors + db_offset)
|
||||
@ -79,7 +73,8 @@ from functools import partial
|
||||
from typing import (Any, Tuple)
|
||||
|
||||
import numpy as np
|
||||
from jax import lax, core
|
||||
from jax import core
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src import ad_util, dtypes
|
||||
|
||||
@ -125,12 +120,11 @@ def approx_max_k(operand: Array,
|
||||
>>> import functools
|
||||
>>> import jax
|
||||
>>> import numpy as np
|
||||
>>> from jax.experimental import ann
|
||||
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
|
||||
... def mips(qy, db, k=10, recall_target=0.95):
|
||||
... dists = jax.lax.dot(qy, db.transpose())
|
||||
... # returns (f32[qy_size, k], i32[qy_size, k])
|
||||
... return ann.approx_max_k(dists, k=k, recall_target=recall_target)
|
||||
... return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
|
||||
>>>
|
||||
>>> qy = jax.numpy.array(np.random.rand(50, 64))
|
||||
>>> db = jax.numpy.array(np.random.rand(1024, 64))
|
||||
@ -185,11 +179,10 @@ def approx_min_k(operand: Array,
|
||||
>>> import functools
|
||||
>>> import jax
|
||||
>>> import numpy as np
|
||||
>>> from jax.experimental import ann
|
||||
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
|
||||
... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
|
||||
... dists = half_db_norms - jax.lax.dot(qy, db.transpose())
|
||||
... return ann.approx_min_k(dists, k=k, recall_target=recall_target)
|
||||
... return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
|
||||
>>>
|
||||
>>> qy = jax.numpy.array(np.random.rand(50, 64))
|
||||
>>> db = jax.numpy.array(np.random.rand(1024, 64))
|
@ -986,6 +986,7 @@ tf_not_yet_impl = [
|
||||
# Not high priority?
|
||||
"after_all",
|
||||
"all_to_all",
|
||||
"approx_top_k",
|
||||
"create_token",
|
||||
"custom_transpose_call",
|
||||
"custom_vmap_call",
|
||||
|
@ -369,5 +369,10 @@ from jax._src.lax.other import (
|
||||
conv_general_dilated_local as conv_general_dilated_local,
|
||||
conv_general_dilated_patches as conv_general_dilated_patches
|
||||
)
|
||||
from jax._src.lax.ann import (
|
||||
approx_max_k as approx_max_k,
|
||||
approx_min_k as approx_min_k,
|
||||
approx_top_k_p as approx_top_k_p
|
||||
)
|
||||
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
|
||||
from jax.lax import linalg as linalg
|
||||
|
@ -21,7 +21,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import ann
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import prod
|
||||
|
||||
@ -80,7 +79,7 @@ class AnnTest(jtu.JaxTestCase):
|
||||
db = rng(db_shape, dtype)
|
||||
scores = lax.dot(qy, db)
|
||||
_, gt_args = lax.top_k(scores, k)
|
||||
_, ann_args = ann.approx_max_k(scores, k, recall_target=recall)
|
||||
_, ann_args = lax.approx_max_k(scores, k, recall_target=recall)
|
||||
self.assertEqual(k, len(ann_args[0]))
|
||||
ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
|
||||
self.assertGreater(ann_recall, recall)
|
||||
@ -103,7 +102,7 @@ class AnnTest(jtu.JaxTestCase):
|
||||
db = rng(db_shape, dtype)
|
||||
scores = lax.dot(qy, db)
|
||||
_, gt_args = lax.top_k(-scores, k)
|
||||
_, ann_args = ann.approx_min_k(scores, k, recall_target=recall)
|
||||
_, ann_args = lax.approx_min_k(scores, k, recall_target=recall)
|
||||
self.assertEqual(k, len(ann_args[0]))
|
||||
ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
|
||||
self.assertGreater(ann_recall, recall)
|
||||
@ -122,9 +121,9 @@ class AnnTest(jtu.JaxTestCase):
|
||||
vals = np.arange(prod(shape), dtype=dtype)
|
||||
vals = self.rng().permutation(vals).reshape(shape)
|
||||
if is_max_k:
|
||||
fn = lambda vs: ann.approx_max_k(vs, k=k)[0]
|
||||
fn = lambda vs: lax.approx_max_k(vs, k=k)[0]
|
||||
else:
|
||||
fn = lambda vs: ann.approx_min_k(vs, k=k)[0]
|
||||
fn = lambda vs: lax.approx_min_k(vs, k=k)[0]
|
||||
jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)
|
||||
|
||||
|
||||
@ -153,10 +152,13 @@ class AnnTest(jtu.JaxTestCase):
|
||||
db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device
|
||||
def parallel_topk(qy, db, db_offset):
|
||||
scores = lax.dot_general(qy, db, (([1],[1]),([],[])))
|
||||
ann_vals, ann_args = ann.approx_min_k(scores, k=k, reduction_dimension=1,
|
||||
recall_target=recall,
|
||||
reduction_input_size_override=db_size,
|
||||
aggregate_to_topk=False)
|
||||
ann_vals, ann_args = lax.approx_min_k(
|
||||
scores,
|
||||
k=k,
|
||||
reduction_dimension=1,
|
||||
recall_target=recall,
|
||||
reduction_input_size_override=db_size,
|
||||
aggregate_to_topk=False)
|
||||
return (ann_vals, ann_args + db_offset)
|
||||
# shape = qy_size, num_devices, approx_dp
|
||||
ann_vals, ann_args = jax.pmap(
|
||||
|
Loading…
x
Reference in New Issue
Block a user