Mirror of https://github.com/ROCm/jax.git
Synced 2025-04-18 04:46:06 +00:00
[JAX] Fix batch logic for approx_min/max_k
The previous logic was copied from lax.sort and was incorrect. Since approx_top_k can handle multi-rank tensors, the only mapping we need under vmap is to set the reduction_dimension correctly.

PiperOrigin-RevId: 440445041
parent 6cb7526390
commit 0bfb3efcd7
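The core of the fix, visible in the first hunk below: rather than broadcasting every batched argument to a common shape (logic inherited from lax.sort, which takes a variadic operand list), the new rule keeps the single operand as-is and only remaps the logical reduction dimension past the axis vmap inserted. A minimal standalone sketch of that remapping, assuming a single inserted batch axis (illustrative only, not the code from ann.py):

def remap_reduction_dim(ndim, batch_axis, reduction_dimension):
  # Physical axes of the batched operand with the batch axis removed;
  # entry i is the physical axis for logical axis i of the unbatched operand.
  dim_map = [d for d in range(ndim) if d != batch_axis]
  return dim_map[reduction_dimension]

# vmap inserts a batch axis at position 0 of a now rank-3 operand, so the
# logical reduction axis 1 moves to physical axis 2:
assert remap_reduction_dim(ndim=3, batch_axis=0, reduction_dimension=1) == 2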
@@ -312,28 +312,23 @@ def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
   return sliced_vals, sliced_args
 
 
-def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
+def _approx_top_k_batch_rule(batch_operands, batch_axes, *, k,
                              reduction_dimension, recall_target, is_max_k,
                              reduction_input_size_override, aggregate_to_topk):
-  prototype_arg, new_bdim = next(
-      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
-  new_args = []
-  for arg, bdim in zip(batched_args, batch_dims):
-    if bdim is None:
-      dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
-      new_args.append(lax.broadcast_in_dim(arg, prototype_arg.shape, dims))
-    else:
-      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
-  new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension)
-  bdims = (new_bdim,) * len(new_args)
-  return (approx_top_k_p.bind(
-      *new_args,
+  assert len(batch_operands) == 1
+  assert len(batch_axes) == 1
+  operand, = batch_operands
+  batch_axis, = batch_axes
+  dim_map = [d for d in range(operand.ndim) if d is not batch_axis]
+  reduction_dimension = dim_map[reduction_dimension]
+  return approx_top_k_p.bind(
+      operand,
       k=k,
-      reduction_dimension=new_reduction_dim,
+      reduction_dimension=reduction_dimension,
       recall_target=recall_target,
-      is_max_k=False,
+      is_max_k=is_max_k,
       reduction_input_size_override=reduction_input_size_override,
-      aggregate_to_topk=aggregate_to_topk), bdims)
+      aggregate_to_topk=aggregate_to_topk), (batch_axis, batch_axis)
 
 
 # Slow jvp implementation using gather.
 
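With the corrected rule, a vmapped call lowers to a single approx_top_k on the batched operand with the reduction dimension shifted, instead of the lax.sort-style broadcasting removed above. A hedged usage sketch (shapes and names invented for illustration):

import jax
import jax.numpy as jnp
from jax import lax

scores = jnp.ones((4, 128, 1024))  # (batch, queries, database)
# Map over axis 0; approx_max_k reduces over the trailing axis by default.
vals, args = jax.vmap(lambda s: lax.approx_max_k(s, k=10))(scores)
print(vals.shape, args.shape)  # (4, 128, 10) (4, 128, 10)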
@@ -174,5 +174,58 @@ class AnnTest(jtu.JaxTestCase):
     self.assertGreater(ann_recall, recall)
 
 
+  def test_vmap_before(self):
+    batch = 4
+    qy_size = 128
+    db_size = 1024
+    feature_dim = 32
+    k = 10
+    rng = jtu.rand_default(self.rng())
+    qy = rng([batch, qy_size, feature_dim], np.float32)
+    db = rng([batch, db_size, feature_dim], np.float32)
+    recall = 0.95
+
+    # Create ground truth
+    gt_scores = lax.dot_general(qy, db, (([2], [2]), ([0], [0])))
+    _, gt_args = lax.top_k(gt_scores, k)
+    gt_args = lax.reshape(gt_args, [qy_size * batch, k])
+
+    # test target
+    def approx_max_k(qy, db):
+      scores = qy @ db.transpose()
+      return lax.approx_max_k(scores, k)
+    _, ann_args = jax.vmap(approx_max_k, (0, 0))(qy, db)
+    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
+    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
+    self.assertGreater(ann_recall, recall)
+
+
+  def test_vmap_after(self):
+    batch = 4
+    qy_size = 128
+    db_size = 1024
+    feature_dim = 32
+    k = 10
+    rng = jtu.rand_default(self.rng())
+    qy = rng([qy_size, feature_dim, batch], np.float32)
+    db = rng([db_size, feature_dim, batch], np.float32)
+    recall = 0.95
+
+    # Create ground truth
+    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
+    _, gt_args = lax.top_k(gt_scores, k)
+    gt_args = lax.transpose(gt_args, [2, 0, 1])
+    gt_args = lax.reshape(gt_args, [qy_size * batch, k])
+
+    # test target
+    def approx_max_k(qy, db):
+      scores = qy @ db.transpose()
+      return lax.approx_max_k(scores, k)
+    _, ann_args = jax.vmap(approx_max_k, (2, 2))(qy, db)
+    ann_args = lax.transpose(ann_args, [2, 0, 1])
+    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
+    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
+    self.assertGreater(ann_recall, recall)
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
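Both tests score their results with compute_recall, a helper defined earlier in ann_test.py and not shown in this diff. A minimal sketch of such a metric, assuming it returns the fraction of retrieved indices that appear in the corresponding ground-truth top-k row (hypothetical reimplementation, not the file's exact code):

import numpy as np

def compute_recall(result_args, ground_truth_args):
  # Both arguments: integer index arrays of shape (num_queries, k).
  hits = 0
  for retrieved, truth in zip(result_args, ground_truth_args):
    gt = set(truth.tolist())
    hits += sum(1 for idx in retrieved.tolist() if idx in gt)
  return hits / ground_truth_args.size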