mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Generalize lax.sort to support lexicographic sorts. (#3709)
This commit is contained in:
parent
0a6b715cd4
commit
804e449389
@ -1217,26 +1217,43 @@ def cummin(operand: Array, axis: int) -> Array:
|
||||
return cummin_p.bind(operand, axis=int(axis))
|
||||
|
||||
# Memoized lexicographic comparators, keyed by (type(num_keys), num_keys).
_lexicographic_comparators = {}


def _lexicographic_comparator(num_keys):
  """Returns one shared ``_sort_lt_comparator`` partial per ``num_keys``.

  ``functools.partial`` objects compare (and hash) by identity, so building a
  new partial on every ``sort`` call would hand ``sort_p`` a "different"
  ``comparator`` parameter each time even when ``num_keys`` is unchanged.
  NOTE(review): that is expected to defeat any cache keyed on the primitive's
  parameters -- confirm against the dispatch/compilation cache in use.
  """
  # The type is part of the key so that values that hash alike (1, 1.0, True)
  # never alias one another's entry.
  cache_key = (type(num_keys), num_keys)
  try:
    return _lexicographic_comparators[cache_key]
  except KeyError:
    return _lexicographic_comparators.setdefault(
        cache_key, partial(_sort_lt_comparator, num_keys=num_keys))
  except TypeError:
    # Unhashable ``num_keys``: keep the uncached behaviour.
    return partial(_sort_lt_comparator, num_keys=num_keys)


def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
  """Wraps XLA's `Sort
  <https://www.tensorflow.org/xla/operation_semantics#sort>`_
  operator.

  Args:
    operand : Array or sequence of arrays
    dimension : integer dimension along which to sort. Default: -1.
    is_stable : boolean specifying whether to use a stable sort. Default: True.
    num_keys : number of operands to treat as sort keys. Default: 1.
      For num_keys > 1, the sort order will be determined lexicographically using
      the first `num_keys` arrays, with the first key being primary.
      The remaining operands will be returned with the same permutation.

  Returns:
    operand : sorted version of the input or inputs. A tuple is returned when
      ``operand`` is a sequence, a single array otherwise.

  Raises:
    TypeError: if ``operand`` is an empty sequence.
    ValueError: if ``num_keys`` is not between 1 and the number of operands
      (for a sequence), or is not 1 (for a single operand).
  """
  if isinstance(operand, Sequence):
    if len(operand) == 0:
      raise TypeError("Sort requires at least one operand")
    if not (1 <= num_keys <= len(operand)):
      raise ValueError(f"num_keys={num_keys} must be between 1 and len(operand)={len(operand)}")
    # The sort axis is canonicalized against the rank of the first operand.
    dimension = _canonicalize_axis(dimension, len(operand[0].shape))
    return tuple(sort_p.bind(*operand, dimension=dimension,
                             is_stable=is_stable,
                             comparator=_lexicographic_comparator(num_keys)))
  else:
    if num_keys != 1:
      raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
    dimension = _canonicalize_axis(dimension, len(operand.shape))
    # sort_p yields a sequence of results; a single operand has exactly one.
    return sort_p.bind(operand, dimension=dimension, is_stable=is_stable,
                       comparator=_sort_lt_comparator)[0]
|
||||
|
||||
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
                 is_stable: bool = True) -> Tuple[Array, Array]:
  """Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``.

  Args:
    keys : array bound as the first sort operand.
    values : array bound as the second sort operand.
    dimension : integer dimension along which to sort, canonicalized against
      the rank of ``keys``. Default: -1.
    is_stable : boolean specifying whether to use a stable sort. Default: True.

  Returns:
    A ``(sorted_keys, sorted_values)`` pair.
  """
  axis = _canonicalize_axis(dimension, len(keys.shape))
  # Both arrays go through a single sort with the default comparator.
  sorted_keys, sorted_values = sort_p.bind(
      keys, values,
      dimension=axis,
      is_stable=is_stable,
      comparator=_sort_lt_comparator)
  return sorted_keys, sorted_values
|
||||
|
||||
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
|
||||
@ -5057,26 +5074,29 @@ def _float_to_int_for_sort(x):
|
||||
sub(unsigned_dtype(onp.iinfo(signed_dtype).max), unsigned), signed_dtype)
|
||||
return select(lt(signed, _zero(signed)), flipped, signed)
|
||||
|
||||
# Default comparator that sorts the operands only on their first arguments.
|
||||
# Default comparator that sorts the operands lexicographically on the
|
||||
# first `num_keys` arguments.
|
||||
# For floating point types, a total order is created where
|
||||
# -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN.
|
||||
# For complex types, the (real, imag) pairs are sorted lexicographically
|
||||
# (following NumPy's semantics).
|
||||
# This code adds complex-number support to the algorithm from:
|
||||
# This code adds complex-number support and lexicographic ordering to the algorithm from:
|
||||
# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33
|
||||
def _sort_lt_comparator(*operands):
|
||||
def _sort_lt_comparator(*operands, num_keys=1):
|
||||
assert len(operands) >= 2 and len(operands) % 2 == 0, operands
|
||||
x, y = operands[:2]
|
||||
assert x.dtype == y.dtype, (x.dtype, y.dtype)
|
||||
if onp.issubdtype(x.dtype, onp.complexfloating):
|
||||
x_keys = [_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))]
|
||||
y_keys = [_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))]
|
||||
elif onp.issubdtype(x.dtype, onp.floating):
|
||||
x_keys = [_float_to_int_for_sort(x)]
|
||||
y_keys = [_float_to_int_for_sort(y)]
|
||||
else:
|
||||
x_keys = [x]
|
||||
y_keys = [y]
|
||||
assert len(operands) // 2 >= num_keys, (operands, num_keys)
|
||||
x_keys, y_keys = [], []
|
||||
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
|
||||
assert x.dtype == y.dtype, (x.dtype, y.dtype)
|
||||
if onp.issubdtype(x.dtype, onp.complexfloating):
|
||||
x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))])
|
||||
y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))])
|
||||
elif onp.issubdtype(x.dtype, onp.floating):
|
||||
x_keys.append(_float_to_int_for_sort(x))
|
||||
y_keys.append(_float_to_int_for_sort(y))
|
||||
else:
|
||||
x_keys.append(x)
|
||||
y_keys.append(y)
|
||||
|
||||
p = None
|
||||
for xk, yk in zip(x_keys[::-1], y_keys[::-1]):
|
||||
@ -5084,32 +5104,33 @@ def _sort_lt_comparator(*operands):
|
||||
else lt(xk, yk))
|
||||
return p
|
||||
|
||||
def _sort_translation_rule(c, *operands, dimension, is_stable, comparator):
  """Lowers ``sort_p`` to ``xops.Sort`` using ``comparator`` as the ordering.

  Args:
    c : builder the sort op is emitted into.
    *operands : XLA operands to sort together.
    dimension : dimension along which to sort.
    is_stable : whether a stable sort is requested.
    comparator : Python callable lowered into the XLA comparator computation.
      It is called with two scalar parameters per operand, in operand order.

  Returns:
    The ``xops.Sort`` result; wrapped in ``xops.Tuple`` when there is exactly
    one operand, so the result is a tuple in both cases.
  """
  element_types = [c.get_shape(op).xla_element_type() for op in operands]

  # Build the comparator as its own sub-computation over scalar parameters:
  # parameters 2*i and 2*i+1 are the pair taken from operand i.
  sub_builder = xla_bridge.make_computation_builder("sort_lt_comparator")
  scalar_params = []
  for i, element_type in enumerate(element_types):
    for j in range(2):
      scalar_shape = xc.Shape.array_shape(element_type, ())
      scalar_params.append(xb.parameter(sub_builder, 2 * i + j, scalar_shape))

  lowered = xla.lower_fun(comparator, multiple_results=False)
  xla_comparator = sub_builder.build(lowered(sub_builder, *scalar_params))

  sorted_out = xops.Sort(c, operands, dimension=dimension,
                         is_stable=is_stable, comparator=xla_comparator)
  if len(operands) == 1:
    return xops.Tuple(c, [sorted_out])
  return sorted_out
|
||||
|
||||
def _sort_jvp(primals, tangents, *, dimension, is_stable, comparator):
  """JVP rule for ``sort_p``: sorts the primals and permutes the tangents.

  An index array along ``dimension`` is appended as an extra (last) operand,
  so the sorted copy of it records the permutation that was applied; that
  permutation is then used to index every tangent.

  Args:
    primals : tuple of operands being sorted.
    tangents : one tangent per primal; ``ad_util.Zero`` entries pass through.
    dimension : dimension along which to sort.
    is_stable : whether a stable sort is requested.
    comparator : comparator forwarded unchanged to ``sort_p``.

  Returns:
    A ``(primals_out, tangents_out)`` pair of tuples.
  """
  operand_shape = primals[0].shape
  int32_max = onp.iinfo(onp.int32).max
  # One iota per axis; int32 unless that axis is too large for it.
  index_grids = [
      broadcasted_iota(onp.int32 if extent < int32_max else onp.int64,
                       operand_shape, axis)
      for axis, extent in enumerate(operand_shape)
  ]

  sort_inputs = primals + (index_grids[dimension],)
  sorted_all = sort_p.bind(*sort_inputs, dimension=dimension,
                           is_stable=is_stable, comparator=comparator)

  # Last output is the sorted iota, i.e. the permutation along `dimension`.
  permutation = sorted_all[-1]
  gather_index = tuple(
      permutation if axis == dimension else grid
      for axis, grid in enumerate(index_grids))

  tangents_out = tuple(
      tangent if type(tangent) is ad_util.Zero else tangent[gather_index]
      for tangent in tangents)
  return tuple(sorted_all[:-1]), tangents_out
|
||||
|
||||
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable):
|
||||
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, comparator):
|
||||
prototype_arg, new_bdim = next(
|
||||
(a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
|
||||
new_args = []
|
||||
@ -5121,7 +5142,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable):
|
||||
new_args.append(batching.moveaxis(arg, bdim, new_bdim))
|
||||
new_dimension = dimension + (new_bdim <= dimension)
|
||||
bdims = (new_bdim,) * len(new_args)
|
||||
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable),
|
||||
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, comparator=comparator),
|
||||
bdims)
|
||||
|
||||
|
||||
|
@ -1440,6 +1440,26 @@ class LaxTest(jtu.JaxTestCase):
|
||||
fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_num_keys={}".format(
        jtu.format_shape_dtype_string(shape, dtype), num_keys),
     "shape": shape, "dtype": dtype, "num_keys": num_keys}
    for dtype in all_dtypes
    for shape in [(3, 5,), (4, 3)]
    for num_keys in range(1, shape[0] + 1)))
def testSortNumKeys(self, shape, dtype, num_keys):
  """Checks ``lax.sort(..., num_keys=k)`` against ``onp.lexsort``.

  The rows of a random ``shape`` array are passed as separate operands, with
  the first ``num_keys`` rows acting as sort keys.
  """
  # TODO(b/141131288): enable complex-valued sorts on TPU.
  if onp.issubdtype(dtype, onp.complexfloating):
    device = jtu.device_under_test()
    old_cpu_jaxlib = device == "cpu" and jax.lib.version <= (0, 1, 47)
    if old_cpu_jaxlib or device == "tpu":
      raise SkipTest("Complex-valued sort not implemented")

  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  def lax_fun(x):
    return lax.sort(tuple(x), num_keys=num_keys)

  def numpy_fun(x):
    # lexsort treats its *last* key as primary, so the keys are reversed to
    # make the first row the primary key, matching lax.sort.
    reversed_keys = x[:num_keys][::-1]
    return tuple(x[:, onp.lexsort(reversed_keys)])

  # NOTE(review): a `self._CompileAndCheck(lax_fun, args_maker)` call was left
  # commented out here; only the NumPy comparison is exercised.
  self._CheckAgainstNumpy(lax_fun, numpy_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
|
||||
jtu.format_shape_dtype_string(shape, key_dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user