mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Revert previous change
PiperOrigin-RevId: 435397906
This commit is contained in:
parent
250ef019b5
commit
c3a4a6e63d
@ -4043,14 +4043,7 @@ def _top_k_batch_rule(batched_args, batch_dims, *, k):
|
||||
return top_k(operand, k=k), (bdim, bdim)
|
||||
|
||||
def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
|
||||
x_shape = ctx.builder.get_shape(x).dimensions()
|
||||
batchdims = x_shape[:-1]
|
||||
if batchdims:
|
||||
# TODO(b/224554623): XLA does not support top-k beyond 2D, collapse the
|
||||
# batch dimensions here to get better performance (otherwise XLA uses sort).
|
||||
x = xops.Reshape(x, (prod(batchdims), x_shape[-1]))
|
||||
ks, idxs = xla.xla_destructure(ctx.builder, xops.TopK(x, k))
|
||||
return xops.Reshape(ks, batchdims+(k,)), xops.Reshape(idxs, batchdims+(k,))
|
||||
return xla.xla_destructure(ctx.builder, xops.TopK(x, k))
|
||||
|
||||
top_k_p = Primitive('top_k')
|
||||
top_k_p.multiple_results = True
|
||||
|
@ -2104,7 +2104,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
jtu.format_shape_dtype_string(shape, dtype), k),
|
||||
"shape": shape, "dtype": dtype, "k": k}
|
||||
for dtype in [np.float32, np.int32, np.uint32]
|
||||
for shape in [(3,), (5, 3), (7, 5, 3)]
|
||||
for shape in [(3,), (5, 3)]
|
||||
for k in [1, 3]))
|
||||
def testTopK(self, shape, dtype, k):
|
||||
def args_maker():
|
||||
|
Loading…
x
Reference in New Issue
Block a user