Fix abstract evaluation rule for lax.top_k. (#2290)

This commit is contained in:
Peter Hawkins 2020-02-24 07:31:46 -08:00 committed by GitHub
parent f6e1d01f94
commit 0416d2a5f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 3 deletions

View File

@ -4109,12 +4109,19 @@ ad.primitive_jvps[sort_key_val_p] = _sort_key_val_jvp
ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
def _top_k_abstract_eval(operand, k):
if k < 0:
raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
if len(operand.shape) == 0:
raise TypeError("top_k operand must have >= 1 dimension, got {}"
.format(operand.shape))
return raise_to_shaped(operand), ShapedArray(operand.shape, onp.int32)
shape = list(operand.shape)
if shape[-1] < k:
msg = "k argument to top_k must be no larger than minor dimension; {} vs {}"
raise ValueError(msg.format(k, shape))
shape[-1] = k
return (ShapedArray(shape, operand.dtype),
ShapedArray(shape, onp.dtype(onp.int32)))
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True

View File

@ -1328,7 +1328,7 @@ class LaxTest(jtu.JaxTestCase):
for shape in [(3,), (5, 3)]
for k in [1, 3]
for rng_factory in [jtu.rand_default]))
@unittest.skipIf(jax.lib.version <= (0, 1, 40), "Test requires jaxlib 0.1.40")
@unittest.skipIf(jax.lib.version < (0, 1, 40), "Test requires jaxlib 0.1.40")
def testTopK(self, shape, dtype, k, rng_factory):
rng = rng_factory()
perm_rng = onp.random.RandomState(0)
@ -1342,6 +1342,7 @@ class LaxTest(jtu.JaxTestCase):
return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
op = lambda vs: lax.top_k(vs, k=k)
self._CheckAgainstNumpy(op, reference_top_k, args_maker)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}"