mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add top_k jvp and batching rules
This commit is contained in:
parent
89e3840e63
commit
dddad2a3dc
@ -123,6 +123,7 @@ Operators
|
||||
sub
|
||||
tan
|
||||
tie_in
|
||||
top_k
|
||||
transpose
|
||||
|
||||
|
||||
|
@ -1183,7 +1183,8 @@ def sort_key_val(keys: Array, values: Array, dimension: int = -1) -> Array:
|
||||
sorted_keys, sorted_values = result
|
||||
return sorted_keys, sorted_values
|
||||
|
||||
def top_k(operand: Array, k: int) -> Array:
|
||||
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
|
||||
"""Returns top ``k`` values and their indices along the last axis of ``operand``."""
|
||||
k = int(k)
|
||||
if k < 0:
|
||||
raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
|
||||
@ -4618,12 +4619,53 @@ def _top_k_abstract_eval(operand, *, k):
|
||||
return (ShapedArray(shape, operand.dtype),
|
||||
ShapedArray(shape, onp.dtype(onp.int32)))
|
||||
|
||||
def _top_k_jvp(primals, tangents, *, k):
  """JVP rule for top_k.

  The values output's tangent is the input tangent gathered at the positions
  the primal top_k selected; the indices output is integer-valued, so its
  tangent is always symbolically zero.
  """
  operand, = primals
  tangent, = tangents
  primals_out = top_k(operand, k)
  if tangent is ad_util.zero:
    # No input perturbation: both outputs get symbolic-zero tangents.
    tangents_out = (ad_util.zero, ad_util.zero)
  else:
    _, k_idxs = primals_out
    idx_shape = k_idxs.shape
    rank = len(idx_shape)
    # Build full gather coordinates: an iota per leading (batch) dimension,
    # plus the top-k indices for the last dimension, concatenated on a new
    # trailing axis.
    gather_index_shape = idx_shape + (1,)
    gather_indices = []
    for dim in range(rank - 1):
      dim_iota = iota(k_idxs.dtype, idx_shape[dim])
      dim_iota = tie_in(operand, dim_iota)
      dim_iota = broadcast_in_dim(dim_iota, gather_index_shape, (dim,))
      gather_indices.append(dim_iota)
    gather_indices.append(reshape(k_idxs, gather_index_shape))
    gather_indices = concatenate(gather_indices, dimension=rank)
    # Each gathered slice is a single scalar element of the tangent array.
    slice_sizes = (1,) * rank
    dnums = GatherDimensionNumbers(
        offset_dims=(),
        collapsed_slice_dims=tuple(range(rank)),
        start_index_map=tuple(range(rank)))
    values_tangent = gather(tangent, gather_indices, dnums, slice_sizes)
    tangents_out = (values_tangent, ad_util.zero)
  return primals_out, tangents_out
|
||||
|
||||
def _top_k_batch_rule(batched_args, batch_dims, *, k):
  """Batching rule for top_k.

  top_k reduces along the last axis. When the batch dimension happens to be
  that last axis, we swap it with the preceding axis, run top_k, and swap the
  results back; otherwise top_k applies directly.
  """
  operand, = batched_args
  bdim, = batch_dims
  if bdim != operand.ndim - 1:
    # The batch dimension does not collide with the reduced (last) axis.
    return top_k(operand, k=k), (bdim, bdim)
  # Swap the last two axes so top_k reduces the original second-to-last axis.
  perm = onp.arange(operand.ndim)
  perm[bdim], perm[bdim - 1] = perm[bdim - 1], perm[bdim]
  values, indices = top_k(transpose(operand, perm), k=k)
  # The swap is its own inverse, so transposing again restores the batch axis.
  return (transpose(values, perm), transpose(indices, perm)), (bdim, bdim)
|
||||
|
||||
# Register the top_k primitive: concrete impl via XLA, shape/dtype inference,
# translation rule, and the autodiff/vmap rules defined above.
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True  # returns (values, indices)
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.translations[top_k_p] = partial(standard_translate, 'top_k')
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
|
||||
|
||||
def _tie_in_transpose_rule(t):
  # tie_in's first operand only establishes a data dependency, so it receives
  # a symbolic-zero cotangent; the cotangent flows through the second operand.
  return [ad_util.zero, t]
|
||||
|
@ -634,6 +634,13 @@ def rand_int(low, high=None):
|
||||
return randint(low, high=high, size=shape, dtype=dtype)
|
||||
return fn
|
||||
|
||||
def rand_unique_int():
  """Returns a generator of integer arrays with all-distinct entries.

  The returned ``fn(shape, dtype)`` draws ``prod(shape)`` values without
  replacement from ``arange(prod(shape))``, i.e. a random permutation laid out
  in ``shape``, seeded deterministically for reproducible tests.
  """
  rng = npr.RandomState(0)
  def fn(shape, dtype):
    population = onp.arange(onp.prod(shape), dtype=dtype)
    return rng.choice(population, size=shape, replace=False)
  return fn
|
||||
|
||||
def rand_bool():
|
||||
rng = npr.RandomState(0)
|
||||
def generator(shape, dtype):
|
||||
|
@ -2487,6 +2487,22 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
|
||||
check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
    for dtype in [onp.float32,]
    for shape in [(4,), (5, 5), (2, 1, 4)]
    for k in [1, 3]
    for rng_factory in [jtu.rand_default]))
def testTopKGrad(self, shape, dtype, k, rng_factory):
  """Checks forward- and reverse-mode gradients of lax.top_k's values output.

  The input is a shuffled permutation of distinct values, so the top-k
  selection is unambiguous; ties would make the finite-difference check
  unreliable. rng_factory is supplied by the parameterization but unused —
  the original bound `rng = rng_factory()` and never read it, so the dead
  binding is removed.
  """
  perm_rng = onp.random.RandomState(0)
  flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
  values = perm_rng.permutation(flat_values).reshape(shape)
  # Only the values output is differentiable; the indices are integers.
  fun = lambda vs: lax.top_k(vs, k=k)[0]
  check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_axes={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
|
||||
@ -3220,6 +3236,27 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
out_shape = lax.broadcast_shapes(shape1, shape2)
|
||||
self.assertTrue(all(type(s) is int for s in out_shape))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}_bdims={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k, bdims),
     "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory}
    for shape in [(4,), (3, 4, 5)]
    for k in [1, 3]
    for bdims in all_bdims(shape)
    # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
    # The top_k indices for integer arrays with identical entries won't match between
    # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
    for dtype, rng_factory in itertools.chain(
        zip(float_dtypes, itertools.repeat(jtu.rand_default)),
        zip(int_dtypes, itertools.repeat(jtu.rand_unique_int)))))
def testTopK(self, shape, dtype, k, bdims, rng_factory):
  """Checks vmap correctness of lax.top_k for each of its two outputs."""
  rng = rng_factory()
  # _CheckBatching doesn't work with tuple outputs, so check the values
  # output (index 0) and the indices output (index 1) separately.
  for out_idx in range(2):
    op = lambda x, i=out_idx: lax.top_k(x, k=k)[i]
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
|
||||
|
||||
# TODO Concatenate
|
||||
# TODO Reverse
|
||||
# TODO DynamicSlice
|
||||
|
Loading…
x
Reference in New Issue
Block a user