mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix abstract evaluation rule for lax.top_k. (#2290)
This commit is contained in:
parent
f6e1d01f94
commit
0416d2a5f2
@ -4109,12 +4109,19 @@ ad.primitive_jvps[sort_key_val_p] = _sort_key_val_jvp
|
||||
ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
|
||||
batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
|
||||
|
||||
|
||||
def _top_k_abstract_eval(operand, k):
|
||||
if k < 0:
|
||||
raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
|
||||
if len(operand.shape) == 0:
|
||||
raise TypeError("top_k operand must have >= 1 dimension, got {}"
|
||||
.format(operand.shape))
|
||||
return raise_to_shaped(operand), ShapedArray(operand.shape, onp.int32)
|
||||
shape = list(operand.shape)
|
||||
if shape[-1] < k:
|
||||
msg = "k argument to top_k must be no larger than minor dimension; {} vs {}"
|
||||
raise ValueError(msg.format(k, shape))
|
||||
shape[-1] = k
|
||||
return (ShapedArray(shape, operand.dtype),
|
||||
ShapedArray(shape, onp.dtype(onp.int32)))
|
||||
|
||||
top_k_p = Primitive('top_k')
|
||||
top_k_p.multiple_results = True
|
||||
|
@ -1328,7 +1328,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for shape in [(3,), (5, 3)]
|
||||
for k in [1, 3]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@unittest.skipIf(jax.lib.version <= (0, 1, 40), "Test requires jaxlib 0.1.40")
|
||||
@unittest.skipIf(jax.lib.version < (0, 1, 40), "Test requires jaxlib 0.1.40")
|
||||
def testTopK(self, shape, dtype, k, rng_factory):
|
||||
rng = rng_factory()
|
||||
perm_rng = onp.random.RandomState(0)
|
||||
@ -1342,6 +1342,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
|
||||
op = lambda vs: lax.top_k(vs, k=k)
|
||||
self._CheckAgainstNumpy(op, reference_top_k, args_maker)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user