[JAX] Fix batch logic for approx_min/max_k

The previous batching logic was copied from lax.sort and was incorrect.
Since approx_top_k can already handle multi-rank tensors, the only mapping
we need is to remap reduction_dimension past the inserted batch axis.

PiperOrigin-RevId: 440445041
Author: jax authors
Date:   2022-04-08 13:50:04 -07:00
Commit: 0bfb3efcd7 (parent 6cb7526390)

2 changed files with 65 additions and 17 deletions
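To make "remap reduction_dimension" concrete: under vmap the operand gains a
batch axis, so a reduction dimension expressed in terms of the unbatched
operand has to be translated to a physical dimension of the batched operand.
A minimal standalone sketch of that remapping (the helper name and example
shapes are ours, not part of the commit):

def remap_reduction_dim(batched_ndim, batch_axis, reduction_dimension):
  # dim_map[i] is the physical position, in the batched operand, of logical
  # dim i of the unbatched operand (all dims except the batch axis, in order).
  dim_map = [d for d in range(batched_ndim) if d != batch_axis]
  return dim_map[reduction_dimension]

# A rank-2 operand batched at axis 0 becomes rank 3: its logical reduction
# dim 1 now lives at physical dim 2.
assert remap_reduction_dim(3, batch_axis=0, reduction_dimension=1) == 2
# With the batch axis appended last, logical dims keep their positions.
assert remap_reduction_dim(3, batch_axis=2, reduction_dimension=1) == 1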

jax/_src/lax/ann.py

@@ -312,28 +312,23 @@ def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
   return sliced_vals, sliced_args


-def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
+def _approx_top_k_batch_rule(batch_operands, batch_axes, *, k,
                              reduction_dimension, recall_target, is_max_k,
                              reduction_input_size_override, aggregate_to_topk):
-  prototype_arg, new_bdim = next(
-      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
-  new_args = []
-  for arg, bdim in zip(batched_args, batch_dims):
-    if bdim is None:
-      dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
-      new_args.append(lax.broadcast_in_dim(arg, prototype_arg.shape, dims))
-    else:
-      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
-  new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension)
-  bdims = (new_bdim,) * len(new_args)
-  return (approx_top_k_p.bind(
-      *new_args,
+  assert len(batch_operands) == 1
+  assert len(batch_axes) == 1
+  operand, = batch_operands
+  batch_axis, = batch_axes
+  dim_map = [d for d in range(operand.ndim) if d is not batch_axis]
+  reduction_dimension = dim_map[reduction_dimension]
+  return approx_top_k_p.bind(
+      operand,
       k=k,
-      reduction_dimension=new_reduction_dim,
+      reduction_dimension=reduction_dimension,
       recall_target=recall_target,
-      is_max_k=False,
+      is_max_k=is_max_k,
       reduction_input_size_override=reduction_input_size_override,
-      aggregate_to_topk=aggregate_to_topk), bdims)
+      aggregate_to_topk=aggregate_to_topk), (batch_axis, batch_axis)


 # Slow jvp implementation using gather.
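For context, a small self-contained sketch of the behavior the fixed rule
enables (shapes and values here are illustrative, not from the commit):
vmapping approx_max_k over a leading batch axis reduces along each
per-example operand's own last dimension.

import jax
import jax.numpy as jnp
from jax import lax

batch, rows, cols, k = 4, 8, 128, 10
scores = jnp.linspace(0., 1., batch * rows * cols).reshape(batch, rows, cols)

# Each vmapped call sees a (rows, cols) operand and reduces its last dim;
# the batch rule remaps that to dim 2 of the (batch, rows, cols) operand.
vals, idxs = jax.vmap(lambda m: lax.approx_max_k(m, k))(scores)
print(vals.shape, idxs.shape)  # (4, 8, 10) (4, 8, 10)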

tests/ann_test.py

@@ -174,5 +174,58 @@ class AnnTest(jtu.JaxTestCase):
     self.assertGreater(ann_recall, recall)

+  def test_vmap_before(self):
+    batch = 4
+    qy_size = 128
+    db_size = 1024
+    feature_dim = 32
+    k = 10
+    rng = jtu.rand_default(self.rng())
+    qy = rng([batch, qy_size, feature_dim], np.float32)
+    db = rng([batch, db_size, feature_dim], np.float32)
+    recall = 0.95
+
+    # Create ground truth
+    gt_scores = lax.dot_general(qy, db, (([2], [2]), ([0], [0])))
+    _, gt_args = lax.top_k(gt_scores, k)
+    gt_args = lax.reshape(gt_args, [qy_size * batch, k])
+
+    # test target
+    def approx_max_k(qy, db):
+      scores = qy @ db.transpose()
+      return lax.approx_max_k(scores, k)
+
+    _, ann_args = jax.vmap(approx_max_k, (0, 0))(qy, db)
+    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
+    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
+    self.assertGreater(ann_recall, recall)
+
+  def test_vmap_after(self):
+    batch = 4
+    qy_size = 128
+    db_size = 1024
+    feature_dim = 32
+    k = 10
+    rng = jtu.rand_default(self.rng())
+    qy = rng([qy_size, feature_dim, batch], np.float32)
+    db = rng([db_size, feature_dim, batch], np.float32)
+    recall = 0.95
+
+    # Create ground truth
+    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
+    _, gt_args = lax.top_k(gt_scores, k)
+    gt_args = lax.transpose(gt_args, [2, 0, 1])
+    gt_args = lax.reshape(gt_args, [qy_size * batch, k])
+
+    # test target
+    def approx_max_k(qy, db):
+      scores = qy @ db.transpose()
+      return lax.approx_max_k(scores, k)
+
+    _, ann_args = jax.vmap(approx_max_k, (2, 2))(qy, db)
+    ann_args = lax.transpose(ann_args, [2, 0, 1])
+    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
+    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
+    self.assertGreater(ann_recall, recall)
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
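The tests above rely on a compute_recall helper defined earlier in the test
file; the sketch below is our reconstruction of what it measures, not the
commit's code: the mean fraction of ground-truth neighbors that appear in
the returned neighbors, taken over all query rows.

import numpy as np

def compute_recall(result_neighbors, ground_truth_neighbors):
  # Both arrays are [num_queries, k]; count, per row, how many ground-truth
  # indices were recovered, and average over all entries.
  assert result_neighbors.shape == ground_truth_neighbors.shape
  hits = sum(
      len(set(row.tolist()) & set(gt.tolist()))
      for row, gt in zip(result_neighbors, ground_truth_neighbors))
  return hits / ground_truth_neighbors.size

gt = np.array([[1, 2, 3], [4, 5, 6]])
found = np.array([[1, 2, 9], [4, 5, 6]])
print(compute_recall(found, gt))  # 5 of 6 recovered -> ~0.833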