Revert previous change

PiperOrigin-RevId: 435397906
This commit is contained in:
Thomas Köppe 2022-03-17 11:19:15 -07:00 committed by jax authors
parent 250ef019b5
commit c3a4a6e63d
2 changed files with 2 additions and 9 deletions

View File

@ -4043,14 +4043,7 @@ def _top_k_batch_rule(batched_args, batch_dims, *, k):
return top_k(operand, k=k), (bdim, bdim)
def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
x_shape = ctx.builder.get_shape(x).dimensions()
batchdims = x_shape[:-1]
if batchdims:
# TODO(b/224554623): XLA does not support top-k beyond 2D, collapse the
# batch dimensions here to get better performance (otherwise XLA uses sort).
x = xops.Reshape(x, (prod(batchdims), x_shape[-1]))
ks, idxs = xla.xla_destructure(ctx.builder, xops.TopK(x, k))
return xops.Reshape(ks, batchdims+(k,)), xops.Reshape(idxs, batchdims+(k,))
return xla.xla_destructure(ctx.builder, xops.TopK(x, k))
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True

View File

@ -2104,7 +2104,7 @@ class LaxTest(jtu.JaxTestCase):
jtu.format_shape_dtype_string(shape, dtype), k),
"shape": shape, "dtype": dtype, "k": k}
for dtype in [np.float32, np.int32, np.uint32]
for shape in [(3,), (5, 3), (7, 5, 3)]
for shape in [(3,), (5, 3)]
for k in [1, 3]))
def testTopK(self, shape, dtype, k):
def args_maker():