mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add an experimental lax.top_k operator. (#2280)
This commit is contained in:
parent
8372a70079
commit
af0967fdbf
@ -1036,6 +1036,11 @@ def sort_key_val(keys, values, dimension=-1):
|
||||
sorted_keys, sorted_values = result
|
||||
return sorted_keys, sorted_values
|
||||
|
||||
def top_k(operand, k):
|
||||
k = int(k)
|
||||
if k < 0:
|
||||
raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
|
||||
return top_k_p.bind(operand, k=k)
|
||||
|
||||
def tie_in(x, y):
|
||||
return tie_in_p.bind(x, y)
|
||||
@ -4034,7 +4039,6 @@ sort_p = standard_primitive(sort_shape, _input_dtype, 'sort')
|
||||
ad.defjvp(sort_p, _sort_jvp_rule)
|
||||
batching.primitive_batchers[sort_p] = _sort_batch_rule
|
||||
|
||||
|
||||
def _sort_key_val_abstract_eval(keys, values, dimension):
|
||||
return raise_to_shaped(keys), raise_to_shaped(values)
|
||||
|
||||
@ -4106,6 +4110,19 @@ ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
|
||||
batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
|
||||
|
||||
|
||||
def _top_k_abstract_eval(operand, k):
|
||||
if len(operand.shape) == 0:
|
||||
raise TypeError("top_k operand must have >= 1 dimension, got {}"
|
||||
.format(operand.shape))
|
||||
return raise_to_shaped(operand), ShapedArray(operand.shape, onp.int32)
|
||||
|
||||
top_k_p = Primitive('top_k')
|
||||
top_k_p.multiple_results = True
|
||||
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
|
||||
top_k_p.def_abstract_eval(_top_k_abstract_eval)
|
||||
xla.translations[top_k_p] = partial(standard_translate, 'top_k')
|
||||
|
||||
|
||||
def _tie_in_transpose_rule(t):
|
||||
return [ad_util.zero, t]
|
||||
|
||||
|
@ -18,6 +18,7 @@ import functools
|
||||
from functools import partial
|
||||
import itertools
|
||||
from typing import Optional, cast
|
||||
import unittest
|
||||
from unittest import skip, SkipTest
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -1319,6 +1320,29 @@ class LaxTest(jtu.JaxTestCase):
|
||||
numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis)
|
||||
self._CheckAgainstNumpy(op, numpy_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_k={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), k),
|
||||
"rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
|
||||
for dtype in [onp.float32, onp.int32, onp.uint32]
|
||||
for shape in [(3,), (5, 3)]
|
||||
for k in [1, 3]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@unittest.skipIf(jax.lib.version <= (0, 1, 40), "Test requires jaxlib 0.1.40")
|
||||
def testTopK(self, shape, dtype, k, rng_factory):
|
||||
rng = rng_factory()
|
||||
perm_rng = onp.random.RandomState(0)
|
||||
def args_maker():
|
||||
flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
|
||||
values = perm_rng.permutation(flat_values).reshape(shape)
|
||||
return [values]
|
||||
def reference_top_k(x):
|
||||
bcast_idxs = onp.broadcast_to(onp.arange(shape[-1]), shape)
|
||||
sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
|
||||
return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
|
||||
op = lambda vs: lax.top_k(vs, k=k)
|
||||
self._CheckAgainstNumpy(op, reference_top_k, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}"
|
||||
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user