mirror of https://github.com/ROCm/jax.git
Make lax.sort support tuple arguments using a variadic sort. (#3085)
* Make lax.sort support tuple arguments using a variadic sort. Change sort_jvp
  to use a gather of ids to compute the JVP rather than sorting repeatedly.
  Remove sort_key_val_p, since it is redundant with a variadic sort_p.
* Fix mypy errors.
* Change JVP rule to use NumPy indexing. Remove redundant case in batching rule.
parent 16cf845148
commit 4ce2aa2563
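For context, a minimal sketch of how the sort API behaves after this change. The array values below are illustrative only, not part of the commit:

import numpy as onp
from jax import lax

keys = onp.array([3., 1., 2.])
vals = onp.array([30., 10., 20.])

# A single operand sorts as before (the keyword `dimension` selects the axis).
lax.sort(keys, dimension=0)                 # [1., 2., 3.]

# A tuple of operands is sorted by the first operand, with the same
# permutation applied to the rest; sort_key_val now reduces to this form.
k, v = lax.sort((keys, vals), dimension=0)  # ([1., 2., 3.], [10., 20., 30.])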
@@ -248,7 +248,6 @@ from .lax import (
   slice_p,
   sort,
   sort_key_val,
-  sort_key_val_p,
   sort_p,
   sqrt,
   sqrt_p,
jax/lax/lax.py (160 lines changed)
@@ -1176,20 +1176,27 @@ def cumprod(operand: Array, axis: int) -> Array:
   """Computes a cumulative product along `axis`."""
   return cumprod_p.bind(operand, axis=int(axis))
 
-def sort(operand: Array, dimension: int = -1) -> Array:
+def sort(operand: Union[Array, Tuple[Array, ...]], dimension: int = -1
+         ) -> Union[Array, Tuple[Array, ...]]:
   """Wraps XLA's `Sort
   <https://www.tensorflow.org/xla/operation_semantics#sort>`_
   operator.
   """
-  return sort_p.bind(operand, dimension=dimension)
+  if isinstance(operand, tuple):
+    if len(operand) == 0:
+      raise TypeError("Sort requires at least one operand")
+    dimension = _canonicalize_axis(dimension, len(operand[0].shape))
+    return tuple(sort_p.bind(*operand, dimension=dimension))
+  else:
+    dimension = _canonicalize_axis(dimension, len(operand.shape))
+    return sort_p.bind(operand, dimension=dimension)[0]
 
 def sort_key_val(keys: Array, values: Array,
                  dimension: int = -1) -> Tuple[Array, Array]:
   """Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``."""
-  # TODO(mattjj): new sort_key_val is variadic
-  result = sort_key_val_p.bind(keys, values, dimension=dimension)
-  sorted_keys, sorted_values = result
-  return sorted_keys, sorted_values
+  dimension = _canonicalize_axis(dimension, len(keys.shape))
+  k, v = sort_p.bind(keys, values, dimension=dimension)
+  return k, v
 
 def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
   """Returns top ``k`` values and their indices along the last axis of ``operand``."""
@@ -4548,100 +4555,54 @@ xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
       multiple_results=False)
 batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)
 
-sort_shape = lambda operand, dimension: operand.shape
-
-def _sort_jvp_rule(g, operand, *, dimension):
-  _, g_out = sort_key_val(operand, g, dimension)
-  return g_out
+def _sort_abstract_eval(*args, **kwargs):
+  args = tuple(raise_to_shaped(arg) for arg in args)
+  if any(arg.shape != args[0].shape for arg in args[1:]):
+    shapes = " ".join(str(a.shape) for a in args)
+    raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
+  return args
+
+def _sort_translation_rule(c, *operands, dimension):
+  out = xops.Sort(c, operands, dimension=dimension, is_stable=True)
+  return out if len(operands) != 1 else xops.Tuple(c, [out])
+
+def _sort_jvp(primals, tangents, *, dimension):
+  shape = primals[0].shape
+  iotas = []
+  for dim, size in enumerate(shape):
+    dtype = onp.int32 if size < onp.iinfo(onp.int32).max else onp.int64
+    iotas.append(broadcasted_iota(dtype, shape, dim))
+  primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension)
+  idx = tuple(primals[-1] if i == dimension else iotas[i]
+              for i in range(len(shape)))
+  tangents_out = tuple(ad_util.zero if t is ad_util.zero else t[idx]
+                       for t in tangents)
+  return tuple(primals[:-1]), tangents_out
 
 def _sort_batch_rule(batched_args, batch_dims, *, dimension):
-  operand, = batched_args
-  bdim, = batch_dims
-  dimension = dimension % (operand.ndim - 1)
-  new_dimension = dimension + (bdim <= dimension)
-  return sort(operand, dimension=new_dimension), bdim
+  prototype_arg, new_bdim = next(
+    (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
+  new_args = []
+  for arg, bdim in zip(batched_args, batch_dims):
+    if bdim is None:
+      dims = onp.delete(onp.arange(prototype_arg.ndim), new_bdim)
+      new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims))
+    else:
+      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
+  new_dimension = dimension + (new_bdim <= dimension)
+  bdims = (new_bdim,) * len(new_args)
+  return sort_p.bind(*new_args, dimension=new_dimension), bdims
 
-def _sort_translation_rule(c, operand, *, dimension):
-  return xops.Sort(c, [operand], dimension=dimension, is_stable=True)
-
-sort_p = standard_primitive(sort_shape, _input_dtype, 'sort',
-                            translation_rule=_sort_translation_rule)
-ad.defjvp(sort_p, _sort_jvp_rule)
+sort_p = Primitive('sort')
+sort_p.multiple_results = True
+sort_p.def_impl(partial(xla.apply_primitive, sort_p))
+sort_p.def_abstract_eval(_sort_abstract_eval)
+xla.translations[sort_p] = _sort_translation_rule
+ad.primitive_jvps[sort_p] = _sort_jvp
 batching.primitive_batchers[sort_p] = _sort_batch_rule
 
-def _sort_key_val_abstract_eval(keys, values, *, dimension):
-  return raise_to_shaped(keys), raise_to_shaped(values)
-
-def _sort_key_val_jvp(primals, tangents, *, dimension):
-  # NOTE(mattjj): this re-sorts three times, but if we had a variadic
-  # sort_key_val, or if we could apply a fixed permutation efficiently, we could
-  # implement this jvp rule with a single sort. The apply_permutation primitive
-  # would make the jvp (and corresponding transpose rule) faster and easier.
-  # This would also be cleaner if we didn't get the sorted keys out.
-  # TODO(mattjj): make sort_key_val variadic, no sorted keys out by default
-  keys, values = primals
-  keys_tangents, values_tangents = tangents
-
-  val_out = sort_key_val(keys, values, dimension)
-
-  if keys_tangents is ad_util.zero:
-    keys_tangents_out = ad_util.zero
-  else:
-    keys_tangents_out = _sort_jvp_rule(keys_tangents, keys, dimension=dimension)
-
-  if values_tangents is ad_util.zero:
-    values_tangents_out = ad_util.zero
-  else:
-    values_tangents_out = _sort_jvp_rule(values_tangents, keys,
-                                         dimension=dimension)
-
-  tangents_out = keys_tangents_out, values_tangents_out
-  return val_out, tangents_out
-
-def _sort_key_val_transpose_rule(t, keys, values, *, dimension):
-  t_keys, t_values = t
-  assert t_keys is ad_util.zero
-  iota = broadcasted_iota(onp.int32, keys.shape, dimension % keys.ndim)
-  _, perm = sort_key_val(keys, iota)
-  keys_result = ad_util.zero if ad.is_undefined_primal(keys) else None
-  values_result = sort_key_val(perm, t_values)[1] if ad.is_undefined_primal(values) else None
-  return [keys_result, values_result]
-
-def _sort_key_val_batch_rule(batched_args, batch_dims, *, dimension):
-  keys, values = batched_args
-  keys_bdim, values_bdim = batch_dims
-  assert keys_bdim is not None or values_bdim is not None
-  if keys_bdim == values_bdim:
-    new_dimension = dimension + (keys_bdim <= dimension)
-    return sort_key_val(keys, values, new_dimension), (keys_bdim, keys_bdim)
-  elif keys_bdim is not None and values_bdim is not None:
-    keys_trans = batching.moveaxis(keys, keys_bdim, values_bdim)
-    new_dimension = dimension + (values_bdim <= dimension)
-    return sort_key_val(keys_trans, values, new_dimension), (values_bdim, values_bdim)
-  elif keys_bdim is None:
-    broadcast_dimensions = onp.delete(onp.arange(values.ndim), values_bdim)
-    new_keys = broadcast_in_dim(keys, values.shape, broadcast_dimensions)
-    new_dimension = dimension + (values_bdim <= dimension)
-    return sort_key_val(new_keys, values, new_dimension), (values_bdim, values_bdim)
-  elif values_bdim is None:
-    broadcast_dimensions = onp.delete(onp.arange(keys.ndim), keys_bdim)
-    new_values = broadcast_in_dim(values, keys.shape, broadcast_dimensions)
-    new_dimension = dimension + (keys_bdim <= dimension)
-    return sort_key_val(keys, new_values, new_dimension), (keys_bdim, keys_bdim)
-  else:
-    assert False  # unreachable
-
-def _sort_key_val_translation_rule(c, keys, values, *, dimension):
-  return xops.Sort(c, [keys, values], dimension=dimension, is_stable=True)
-
-sort_key_val_p = Primitive('sort_key_val')
-sort_key_val_p.multiple_results = True
-sort_key_val_p.def_impl(partial(xla.apply_primitive, sort_key_val_p))
-sort_key_val_p.def_abstract_eval(_sort_key_val_abstract_eval)
-xla.translations[sort_key_val_p] = _sort_key_val_translation_rule
-ad.primitive_jvps[sort_key_val_p] = _sort_key_val_jvp
-ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
-batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
-
 def _top_k_abstract_eval(operand, *, k):
   if k < 0:
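A note on the JVP rule added above: rather than re-sorting each tangent, it sorts an iota alongside the primals, so the last sorted output is exactly the permutation that the sort applied; each tangent is then gathered with a single NumPy-style indexing operation. A standalone NumPy sketch of the same idea (illustrative, not code from this commit):

import numpy as onp

keys = onp.array([[3., 1., 2.],
                  [6., 5., 4.]])
tangents = onp.array([[.3, .1, .2],
                      [.6, .5, .4]])

# argsort stands in for sorting an iota together with the keys: it yields the
# permutation that sorting applies along the chosen dimension.
perm = onp.argsort(keys, axis=1)

# One gather replaces a second sort: index each row of the tangents by the
# permutation of its keys.
rows = onp.arange(keys.shape[0])[:, None]
tangents_out = tangents[rows, perm]   # [[.1, .2, .3], [.4, .5, .6]]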
@@ -5205,7 +5166,6 @@ def _abstractify(x):
   return raise_to_shaped(core.get_aval(x))
 
 
-
 def _check_user_dtype_supported(dtype, fun_name=None):
   onp_dtype = onp.dtype(dtype)
   if onp_dtype.kind not in "biufc" and onp_dtype.type != dtypes.bfloat16:

@@ -5220,3 +5180,15 @@ def _check_user_dtype_supported(dtype, fun_name=None):
   fun_name = "requested in {}".format(fun_name) if fun_name else ""
   truncated_dtype = dtypes.canonicalize_dtype(dtype).name
   warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
+
+
+def _canonicalize_axis(axis, num_dims):
+  """Canonicalize an axis in (-num_dims, num_dims) to [0, num_dims)."""
+  axis = int(axis)
+  if axis < 0:
+    axis = axis + num_dims
+  if axis < 0 or axis >= num_dims:
+    raise ValueError(
+        "axis {} is out of bounds for array of dimension {}".format(
+            axis, num_dims))
+  return axis
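The new _canonicalize_axis helper normalizes a possibly negative axis before binding the primitive. A quick illustration of its contract (hypothetical REPL use, assuming the private helper is in scope):

_canonicalize_axis(-1, 3)   # 2
_canonicalize_axis(0, 3)    # 0
_canonicalize_axis(3, 3)    # raises ValueError: axis 3 is out of bounds ...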
@@ -2952,9 +2952,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
     raise ValueError("'order' argument to sort is not supported.")
 
   if axis is None:
-    return lax.sort(a.ravel(), 0)
+    return lax.sort(a.ravel(), dimension=0)
   else:
-    return lax.sort(a, _canonicalize_axis(axis, ndim(a)))
+    return lax.sort(a, dimension=_canonicalize_axis(axis, ndim(a)))
 
 
 @_wraps(np.argsort)
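The jax.numpy wrapper now passes `dimension` by keyword; observable behavior is unchanged. A small illustrative example:

import jax.numpy as jnp

a = jnp.array([[3, 1],
               [2, 0]])
jnp.sort(a, axis=None)   # flattened first: [0 1 2 3]
jnp.sort(a, axis=-1)     # per row:         [[1 3] [0 2]]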
@@ -1310,7 +1310,7 @@ class LaxTest(jtu.JaxTestCase):
   def testSort(self, shape, dtype, axis, rng_factory):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    fun = lambda x: lax.sort(x, axis)
+    fun = lambda x: lax.sort(x, dimension=axis)
     self._CompileAndCheck(fun, args_maker, check_dtypes=True)
 
   @parameterized.named_parameters(jtu.cases_from_list(
@@ -1324,7 +1324,7 @@ class LaxTest(jtu.JaxTestCase):
   def testSortAgainstNumpy(self, shape, dtype, axis, rng_factory):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    op = lambda x: lax.sort(x, axis)
+    op = lambda x: lax.sort(x, dimension=axis)
     numpy_op = lambda x: lax_reference.sort(x, axis)
     self._CheckAgainstNumpy(op, numpy_op, args_maker)
 
@@ -2442,7 +2442,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
   def testSortGrad(self, shape, dtype, axis, rng_factory):
     rng = rng_factory(self.rng())
     operand = rng(shape, dtype)
-    sort = lambda x: lax.sort(x, axis)
+    sort = lambda x: lax.sort(x, dimension=axis)
     check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)
 
   # TODO(b/205052657): enable more tests when supported
@@ -2678,7 +2678,8 @@ class LaxVmapTest(jtu.JaxTestCase):
 
   def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                      rtol=None, atol=None):
-    batched_shapes = list(map(partial(add_bdim, bdim_size), bdims, shapes))
+    batched_shapes = list(jax.util.safe_map(partial(add_bdim, bdim_size),
+                                            bdims, shapes))
     args = [rng(shape, dtype)
             for shape, dtype in jax.util.safe_zip(batched_shapes, dtypes)]
     args_slice = args_slicer(args, bdims)
@@ -3241,12 +3242,33 @@ class LaxVmapTest(jtu.JaxTestCase):
     op2 = lambda x: lax.top_k(x, k=k)[1]
     self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}"
+       .format(jtu.format_shape_dtype_string(shape, onp.float32), dimension,
+               arity, bdims),
+       "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims}
+      for shape in [(2, 3)]
+      for dimension in [0, 1]
+      for arity in range(3)
+      for bdims in all_bdims(*((shape,) * arity))))
+  def testSort(self, shape, dimension, arity, bdims):
+    rng = jtu.rand_default(self.rng())
+    if arity == 1:
+      fun = partial(lax.sort, dimension=dimension)
+      self._CheckBatching(fun, 5, bdims, (shape,) * arity, (onp.float32,) * arity,
+                          rng)
+    else:
+      for i in range(arity):
+        fun = lambda *args, i=i: lax.sort(args, dimension=dimension)[i]
+        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
+                            (onp.float32,) * arity, rng)
+
   # TODO Concatenate
   # TODO Reverse
   # TODO DynamicSlice
   # TODO DynamicUpdateSlice
-  # TODO Sort
-  # TODO SortKeyVal
   # TODO Collapse
   # TODO Scatter
 
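The new batching rule moves every batched operand's batch dimension into agreement and broadcasts any unbatched operands before binding the variadic primitive, which is what the tuple cases of testSort above exercise. A minimal sketch of what that enables under vmap (illustrative, not part of the commit):

import jax
import numpy as onp
from jax import lax

keys = onp.array([[3., 1., 2.],
                  [4., 6., 5.]])    # batched along axis 0
vals = onp.array([10., 20., 30.])   # unbatched; the rule broadcasts it

# Sort (keys_row, vals) per batch element; only the first operand is batched.
k, v = jax.vmap(lambda k: lax.sort((k, vals), dimension=0))(keys)
# k: [[1., 2., 3.], [4., 5., 6.]]
# v: [[20., 30., 10.], [10., 30., 20.]]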