Add an experimental lax.top_k operator. (#2280)

This commit is contained in:
Peter Hawkins 2020-02-20 17:15:25 -08:00 committed by GitHub
parent 8372a70079
commit af0967fdbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 1 deletion

View File

@@ -1036,6 +1036,11 @@ def sort_key_val(keys, values, dimension=-1):
sorted_keys, sorted_values = result
return sorted_keys, sorted_values
def top_k(operand, k):
  """Returns the largest ``k`` elements of ``operand``'s minor dimension,
  together with their indices, via the experimental ``top_k`` primitive.
  """
  count = int(k)
  if count >= 0:
    return top_k_p.bind(operand, k=count)
  raise ValueError("k argument to top_k must be nonnegative, got {}".format(count))
def tie_in(x, y):
  """Binds ``y`` to ``x`` through the ``tie_in`` primitive and returns the
  bound result (value-equivalent to ``y``)."""
  return tie_in_p.bind(x, y)
@@ -4034,7 +4039,6 @@ sort_p = standard_primitive(sort_shape, _input_dtype, 'sort')
# Register the JVP (forward-mode autodiff) and vmap (batching) rules for the
# sort primitive.
ad.defjvp(sort_p, _sort_jvp_rule)
batching.primitive_batchers[sort_p] = _sort_batch_rule
def _sort_key_val_abstract_eval(keys, values, dimension):
  """Abstract eval for sort_key_val: outputs mirror the input avals."""
  del dimension  # Sorting along any axis leaves shapes and dtypes unchanged.
  out_keys = raise_to_shaped(keys)
  out_values = raise_to_shaped(values)
  return out_keys, out_values
@@ -4106,6 +4110,19 @@ ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
# Register the vmap (batching) rule for the sort_key_val primitive.
batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
def _top_k_abstract_eval(operand, k):
if len(operand.shape) == 0:
raise TypeError("top_k operand must have >= 1 dimension, got {}"
.format(operand.shape))
return raise_to_shaped(operand), ShapedArray(operand.shape, onp.int32)
# Set up the top_k primitive.  It has two outputs (values, indices), falls
# back to XLA compilation-and-execution for its impl, uses
# _top_k_abstract_eval for shape inference, and lowers via standard_translate
# (presumably to XLA's TopK custom call -- name-based dispatch).
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.translations[top_k_p] = partial(standard_translate, 'top_k')
def _tie_in_transpose_rule(t):
  """Transpose of tie_in: the first operand (x) receives a symbolic zero
  cotangent, and the second operand (y) receives ``t`` unchanged."""
  cotangents = [ad_util.zero, t]
  return cotangents

View File

@@ -18,6 +18,7 @@ import functools
from functools import partial
import itertools
from typing import Optional, cast
import unittest
from unittest import skip, SkipTest
from absl.testing import absltest
@@ -1319,6 +1320,29 @@ class LaxTest(jtu.JaxTestCase):
numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
    for dtype in [onp.float32, onp.int32, onp.uint32]
    for shape in [(3,), (5, 3)]
    for k in [1, 3]
    for rng_factory in [jtu.rand_default]))
# NOTE(review): the condition skips on jaxlib <= 0.1.40, i.e. it requires
# 0.1.41 or newer, but the message names 0.1.40 -- confirm which is intended.
@unittest.skipIf(jax.lib.version <= (0, 1, 40), "Test requires jaxlib 0.1.40")
def testTopK(self, shape, dtype, k, rng_factory):
  """Checks lax.top_k against a reference built from a full sort.

  The operand is a shuffled arange, so every value is distinct and the
  expected indices are unambiguous.
  """
  # The original body assigned `rng = rng_factory()` but never used it;
  # args_maker draws from its own deterministic permutation RNG instead.
  del rng_factory
  perm_rng = onp.random.RandomState(0)
  def args_maker():
    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
    values = perm_rng.permutation(flat_values).reshape(shape)
    return [values]
  def reference_top_k(x):
    # Sort ascending with positional indices as the carried values, then take
    # the last k entries of each row in reverse order (largest first).
    bcast_idxs = onp.broadcast_to(onp.arange(shape[-1]), shape)
    sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
    return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
  op = lambda vs: lax.top_k(vs, k=k)
  self._CheckAgainstNumpy(op, reference_top_k, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),