rocm_jax/tests/ann_test.py
Jake VanderPlas f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00

210 lines
7.1 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from absl.testing import absltest
import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
jax.config.parse_flags_with_absl()
# Decorator factory that silences the "jit-of-pmap" warning which JAX emits
# when a pmapped function ends up wrapped in jit inside a test.
ignore_jit_of_pmap_warning = partial(
    jtu.ignore_warning, message=".*jit-of-pmap.*")
def compute_recall(result_neighbors, ground_truth_neighbors) -> float:
  """Computes the recall of an approximate nearest neighbor search.

  Args:
    result_neighbors: int32 numpy array of shape [num_queries,
      neighbors_per_query] whose values are indices into the dataset.
    ground_truth_neighbors: int32 numpy array of shape [num_queries,
      ground_truth_neighbors_per_query] whose values are indices into the
      dataset.

  Returns:
    The recall: fraction of ground-truth neighbors recovered by the
    approximate search.
  """
  assert len(result_neighbors.shape) == 2, (
      "shape = [num_queries, neighbors_per_query]")
  assert len(ground_truth_neighbors.shape) == 2, (
      "shape = [num_queries, ground_truth_neighbors_per_query]")
  assert result_neighbors.shape[0] == ground_truth_neighbors.shape[0]

  hits = 0
  for approx_row, truth_row in zip(result_neighbors, ground_truth_neighbors):
    # Set membership makes each per-query containment check O(1).
    truth_set = set(np.asarray(truth_row))
    hits += sum(1 for idx in approx_row if idx.item() in truth_set)
  return hits / ground_truth_neighbors.size
class AnnTest(jtu.JaxTestCase):
  """Tests for the approximate top-k primitives lax.approx_max_k/approx_min_k."""

  # TODO(b/258315194) Investigate probability property when input is around
  # few thousands.
  @jtu.sample_product(
      qy_shape=[(200, 128), (128, 128)],
      db_shape=[(128, 500), (128, 3000)],
      dtype=jtu.dtypes.all_floating,
      k=[1, 10],
      recall=[0.95],
  )
  def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
    """Compares approx_max_k against exact top_k on random query/db scores."""
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(scores, k)  # exact top-k serves as ground truth
    _, ann_args = lax.approx_max_k(scores, k, recall_target=recall)
    self.assertEqual(k, len(ann_args[0]))
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    # Accept up to 10% below the requested recall target.
    self.assertGreaterEqual(ann_recall, recall*0.9)

  @jtu.sample_product(
      qy_shape=[(200, 128), (128, 128)],
      db_shape=[(128, 500), (128, 3000)],
      dtype=jtu.dtypes.all_floating,
      k=[1, 10],
      recall=[0.95],
  )
  def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall):
    """Compares approx_min_k against exact top_k of the negated scores."""
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(-scores, k)  # negate scores: exact min-k ground truth
    _, ann_args = lax.approx_min_k(scores, k, recall_target=recall)
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    # Accept up to 10% below the requested recall target.
    self.assertGreaterEqual(ann_recall, recall*0.9)

  @jtu.sample_product(
      dtype=[np.float32],
      shape=[(4,), (5, 5), (2, 1, 4)],
      k=[1, 3],
      is_max_k=[True, False],
  )
  def test_autodiff(self, shape, dtype, k, is_max_k):
    """Checks 2nd-order fwd and rev gradients of the approx top-k values."""
    # Distinct values (0..n-1), shuffled, so the top-k selection is unambiguous
    # and the gradient is well defined.
    vals = np.arange(math.prod(shape), dtype=dtype)
    vals = self.rng().permutation(vals).reshape(shape)
    if is_max_k:
      fn = lambda vs: lax.approx_max_k(vs, k=k)[0]
    else:
      fn = lambda vs: lax.approx_min_k(vs, k=k)[0]
    jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)

  @jtu.sample_product(
      qy_shape=[(200, 128), (128, 128)],
      db_shape=[(2048, 128)],
      dtype=jtu.dtypes.all_floating,
      k=[1, 10],
      recall=[0.9, 0.95],
  )
  def test_pmap(self, qy_shape, db_shape, dtype, k, recall):
    """Runs approx_min_k under pmap with the database sharded across devices."""
    num_devices = jax.device_count()
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    db_size = db.shape[0]
    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
    _, gt_args = lax.top_k(-gt_scores, k)  # negate the score to get min-k
    db_per_device = db_size//num_devices
    sharded_db = db.reshape(num_devices, db_per_device, 128)
    # Per-shard index offsets used to translate shard-local indices back to
    # positions in the full database.
    db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device
    def parallel_topk(qy, db, db_offset):
      # Runs on each device over its shard of the database; keeps raw
      # per-shard candidates (aggregate_to_topk=False) for later merging.
      scores = lax.dot_general(qy, db, (([1],[1]),([],[])))
      ann_vals, ann_args = lax.approx_min_k(
          scores,
          k=k,
          reduction_dimension=1,
          recall_target=recall,
          reduction_input_size_override=db_size,
          aggregate_to_topk=False)
      return (ann_vals, ann_args + db_offset)
    # shape = qy_size, num_devices, approx_dp
    ann_vals, ann_args = jax.pmap(
        parallel_topk,
        in_axes=(None, 0, 0),
        out_axes=(1, 1))(qy, sharded_db, db_offsets)
    # collapse num_devices and approx_dp
    ann_vals = lax.collapse(ann_vals, 1, 3)
    ann_args = lax.collapse(ann_args, 1, 3)
    # Merge per-device candidates: sort by score, then keep the best k.
    ann_vals, ann_args = lax.sort_key_val(ann_vals, ann_args, dimension=1)
    ann_args = lax.slice_in_dim(ann_args, start_index=0, limit_index=k, axis=1)
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  def test_vmap_before(self):
    """vmaps approx_max_k over a leading batch dimension of qy and db."""
    batch = 4
    qy_size = 128
    db_size = 1024
    feature_dim = 32
    k = 10
    rng = jtu.rand_default(self.rng())
    qy = rng([batch, qy_size, feature_dim], np.float32)
    db = rng([batch, db_size, feature_dim], np.float32)
    recall = 0.95

    # Create ground truth
    gt_scores = lax.dot_general(qy, db, (([2], [2]), ([0], [0])))
    _, gt_args = lax.top_k(gt_scores, k)
    # Flatten the batch dimension so recall is computed per (batch, query) row.
    gt_args = lax.reshape(gt_args, [qy_size * batch, k])

    # test target
    def approx_max_k(qy, db):
      scores = qy @ db.transpose()
      return lax.approx_max_k(scores, k)
    _, ann_args = jax.vmap(approx_max_k, (0, 0))(qy, db)
    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  def test_vmap_after(self):
    """vmaps approx_max_k over a trailing batch dimension of qy and db."""
    batch = 4
    qy_size = 128
    db_size = 1024
    feature_dim = 32
    k = 10
    rng = jtu.rand_default(self.rng())
    qy = rng([qy_size, feature_dim, batch], np.float32)
    db = rng([db_size, feature_dim, batch], np.float32)
    recall = 0.95

    # Create ground truth
    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
    _, gt_args = lax.top_k(gt_scores, k)
    # Move the batch axis to the front, then flatten to match the ANN layout.
    gt_args = lax.transpose(gt_args, [2, 0, 1])
    gt_args = lax.reshape(gt_args, [qy_size * batch, k])

    # test target
    def approx_max_k(qy, db):
      scores = qy @ db.transpose()
      return lax.approx_max_k(scores, k)
    _, ann_args = jax.vmap(approx_max_k, (2, 2))(qy, db)
    ann_args = lax.transpose(ann_args, [2, 0, 1])
    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)
# Standard JAX test entry point: run under absltest with the JAX test loader.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())