Generalize lax.sort to support lexicographic sorts. (#3709)

This commit is contained in:
Jake Vanderplas 2020-07-09 20:05:19 -07:00 committed by GitHub
parent 0a6b715cd4
commit 804e449389
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 24 deletions

View File

@ -1217,26 +1217,43 @@ def cummin(operand: Array, axis: int) -> Array:
return cummin_p.bind(operand, axis=int(axis))
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
  """Wraps XLA's `Sort
  <https://www.tensorflow.org/xla/operation_semantics#sort>`_
  operator.

  Args:
    operand : Array or sequence of arrays
    dimension : integer dimension along which to sort. Default: -1.
    is_stable : boolean specifying whether to use a stable sort. Default: True.
    num_keys : number of operands to treat as sort keys. Default: 1.
      For num_keys > 1, the sort order will be determined lexicographically using
      the first `num_keys` arrays, with the first key being primary.
      The remaining operands will be returned with the same permutation.

  Returns:
    operand : sorted version of the input or inputs. A sequence input yields
      a tuple with one entry per operand; a single array input yields a
      single array.

  Raises:
    TypeError: if ``operand`` is an empty sequence.
    ValueError: if ``operand`` is a sequence and ``num_keys`` is not between
      1 and ``len(operand)``, or if ``operand`` is a single array and
      ``num_keys`` is not 1.
  """
  if isinstance(operand, Sequence):
    if len(operand) == 0:
      raise TypeError("Sort requires at least one operand")
    if not (1 <= num_keys <= len(operand)):
      raise ValueError(f"num_keys={num_keys} must be between 1 and len(operand)={len(operand)}")
    # The sort axis is resolved against the rank of the first operand only.
    dimension = _canonicalize_axis(dimension, len(operand[0].shape))
    # num_keys is baked into the comparator rather than passed to the
    # primitive as its own parameter.
    comparator = partial(_sort_lt_comparator, num_keys=num_keys)
    return tuple(sort_p.bind(*operand, dimension=dimension,
                             is_stable=is_stable, comparator=comparator))
  else:
    if num_keys != 1:
      raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
    dimension = _canonicalize_axis(dimension, len(operand.shape))
    # The primitive's result is indexed even for one operand; unwrap it so a
    # single array in gives a single array out.
    return sort_p.bind(operand, dimension=dimension, is_stable=is_stable,
                       comparator=_sort_lt_comparator)[0]
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
                 is_stable: bool = True) -> Tuple[Array, Array]:
  """Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``.

  Args:
    keys : array whose entries determine the sort order.
    values : array reordered with the permutation that sorts ``keys``.
    dimension : integer dimension along which to sort, resolved against the
      rank of ``keys``. Default: -1.
    is_stable : boolean specifying whether to use a stable sort. Default: True.

  Returns:
    A ``(sorted_keys, permuted_values)`` pair.
  """
  dimension = _canonicalize_axis(dimension, len(keys.shape))
  # The comparator is passed with its default num_keys=1, so only `keys`
  # takes part in the ordering; `values` is carried along.
  k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable,
                     comparator=_sort_lt_comparator)
  return k, v
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
@ -5057,26 +5074,29 @@ def _float_to_int_for_sort(x):
sub(unsigned_dtype(onp.iinfo(signed_dtype).max), unsigned), signed_dtype)
return select(lt(signed, _zero(signed)), flipped, signed)
# Default comparator that sorts the operands lexicographically on the
# first `num_keys` arguments.
# For floating point types, a total order is created where
# -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN.
# For complex types, the (real, imag) pairs are sorted lexicographically
# (following NumPy's semantics).
# This code adds complex-number support and lexicographic ordering to the algorithm from:
# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33
def _sort_lt_comparator(*operands):
def _sort_lt_comparator(*operands, num_keys=1):
assert len(operands) >= 2 and len(operands) % 2 == 0, operands
x, y = operands[:2]
assert x.dtype == y.dtype, (x.dtype, y.dtype)
if onp.issubdtype(x.dtype, onp.complexfloating):
x_keys = [_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))]
y_keys = [_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))]
elif onp.issubdtype(x.dtype, onp.floating):
x_keys = [_float_to_int_for_sort(x)]
y_keys = [_float_to_int_for_sort(y)]
else:
x_keys = [x]
y_keys = [y]
assert len(operands) // 2 >= num_keys, (operands, num_keys)
x_keys, y_keys = [], []
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
assert x.dtype == y.dtype, (x.dtype, y.dtype)
if onp.issubdtype(x.dtype, onp.complexfloating):
x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))])
y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))])
elif onp.issubdtype(x.dtype, onp.floating):
x_keys.append(_float_to_int_for_sort(x))
y_keys.append(_float_to_int_for_sort(y))
else:
x_keys.append(x)
y_keys.append(y)
p = None
for xk, yk in zip(x_keys[::-1], y_keys[::-1]):
@ -5084,32 +5104,33 @@ def _sort_lt_comparator(*operands):
else lt(xk, yk))
return p
def _sort_translation_rule(c, *operands, dimension, is_stable, comparator):
  """Emits an XLA Sort over ``operands`` using ``comparator`` as the ordering.

  Args:
    c : the XLA computation builder the sort is added to.
    *operands : XLA ops to sort together.
    dimension : dimension along which to sort.
    is_stable : whether the sort must be stable.
    comparator : Python callable lowered into the XLA comparison
      sub-computation. It is called with two scalar parameters per operand.

  Returns:
    The XLA op for the sort, wrapped in a tuple when there is one operand.
  """
  types = [c.get_shape(x).xla_element_type() for x in operands]
  subc = xla_bridge.make_computation_builder("sort_lt_comparator")
  # Operand i contributes scalar parameters 2*i and 2*i + 1: the two elements
  # being compared for that operand.
  params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
            for i, typ in enumerate(types) for j in range(2)]
  result = xla.lower_fun(comparator,
                         multiple_results=False)(subc, *params)
  # Kept under its own name so the built computation does not shadow the
  # `comparator` callable argument.
  xla_comparator = subc.build(result)
  out = xops.Sort(c, operands, dimension=dimension, is_stable=is_stable,
                  comparator=xla_comparator)
  # NOTE(review): presumably xops.Sort yields a bare array (not a tuple) for a
  # single operand, hence the explicit wrap -- confirm against the XLA client.
  return out if len(operands) != 1 else xops.Tuple(c, [out])
def _sort_jvp(primals, tangents, *, dimension, is_stable, comparator):
  """JVP rule for sort: tangents are reordered by the sorting permutation.

  An index array along ``dimension`` is sorted together with the primals as
  an extra trailing operand. Its sorted form is the permutation that was
  applied, which is then used to gather each tangent.

  Args:
    primals : tuple of the primal operands.
    tangents : tangents matching ``primals``; entries may be ``ad_util.Zero``.
    dimension : dimension along which to sort.
    is_stable : whether the sort must be stable.
    comparator : comparator forwarded unchanged to the sort primitive.

  Returns:
    A ``(primals_out, tangents_out)`` pair of tuples.
  """
  shape = primals[0].shape
  iotas = []
  for dim, size in enumerate(shape):
    # Use 64-bit indices only when the dimension does not fit in int32.
    dtype = onp.int32 if size < onp.iinfo(onp.int32).max else onp.int64
    iotas.append(broadcasted_iota(dtype, shape, dim))
  # The index array goes last, so it acts as a payload rather than a key
  # (assuming the comparator keys on leading operands only).
  sorted_outs = sort_p.bind(*(primals + (iotas[dimension],)),
                            dimension=dimension, is_stable=is_stable,
                            comparator=comparator)
  # Gather index: the permutation along the sorted dimension, and plain
  # positions along every other dimension.
  idx = tuple(sorted_outs[-1] if i == dimension else iotas[i]
              for i in range(len(shape)))
  tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
  return tuple(sorted_outs[:-1]), tangents_out
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable):
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, comparator):
prototype_arg, new_bdim = next(
(a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
new_args = []
@ -5121,7 +5142,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable):
new_args.append(batching.moveaxis(arg, bdim, new_bdim))
new_dimension = dimension + (new_bdim <= dimension)
bdims = (new_bdim,) * len(new_args)
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable),
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, comparator=comparator),
bdims)

View File

@ -1440,6 +1440,26 @@ class LaxTest(jtu.JaxTestCase):
fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_shape={}_num_keys={}".format(
     jtu.format_shape_dtype_string(shape, dtype), num_keys),
   "shape": shape, "dtype": dtype, "num_keys": num_keys}
  for dtype in all_dtypes
  for shape in [(3, 5,), (4, 3)]
  for num_keys in range(1, shape[0] + 1)))
def testSortNumKeys(self, shape, dtype, num_keys):
  """Checks multi-key ``lax.sort`` against ``onp.lexsort``.

  The rows of a random ``shape`` array are passed as separate operands and
  the first ``num_keys`` rows act as sort keys.
  """
  # TODO(b/141131288): enable complex-valued sorts on TPU.
  if onp.issubdtype(dtype, onp.complexfloating):
    device = jtu.device_under_test()
    unsupported_cpu = device == "cpu" and jax.lib.version <= (0, 1, 47)
    if unsupported_cpu or device == "tpu":
      raise SkipTest("Complex-valued sort not implemented")

  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  def lax_fun(x):
    return lax.sort(tuple(x), num_keys=num_keys)

  def numpy_fun(x):
    # lexsort treats its LAST key as primary, so the keys are reversed to
    # make row 0 the primary key.
    order = onp.lexsort(x[:num_keys][::-1])
    return tuple(x[:, order])

  # NOTE: only the NumPy comparison runs; a `_CompileAndCheck(lax_fun,
  # args_maker)` call was left disabled here.
  self._CheckAgainstNumpy(lax_fun, numpy_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
jtu.format_shape_dtype_string(shape, key_dtype),