lexicographic sort_p: accept num_keys rather than comparator (#3715)

This commit is contained in:
Jake Vanderplas 2020-07-10 09:58:35 -07:00 committed by GitHub
parent d2f9c46a0c
commit 60d852773e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1242,18 +1242,18 @@ def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
dimension = _canonicalize_axis(dimension, len(operand[0].shape))
return tuple(sort_p.bind(*operand, dimension=dimension,
is_stable=is_stable,
comparator=partial(_sort_lt_comparator, num_keys=num_keys)))
num_keys=num_keys))
else:
if num_keys != 1:
raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
dimension = _canonicalize_axis(dimension, len(operand.shape))
return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, comparator=_sort_lt_comparator)[0]
return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, num_keys=1)[0]
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
is_stable: bool = True) -> Tuple[Array, Array]:
"""Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``."""
dimension = _canonicalize_axis(dimension, len(keys.shape))
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, comparator=_sort_lt_comparator)
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1)
return k, v
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
@ -5105,32 +5105,32 @@ def _sort_lt_comparator(*operands, num_keys=1):
return p
def _sort_translation_rule(c, *operands, dimension, is_stable, comparator):
def _sort_translation_rule(c, *operands, dimension, is_stable, num_keys):
types = [c.get_shape(x).xla_element_type() for x in operands]
subc = xla_bridge.make_computation_builder("sort_lt_comparator")
params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
for i, typ in enumerate(types) for j in range(2)]
result = xla.lower_fun(comparator,
result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
multiple_results=False)(subc, *params)
comparator = subc.build(result)
out = xops.Sort(c, operands, dimension=dimension, is_stable=is_stable,
comparator=comparator)
return out if len(operands) != 1 else xops.Tuple(c, [out])
def _sort_jvp(primals, tangents, *, dimension, is_stable, comparator):
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
shape = primals[0].shape
iotas = []
for dim, size in enumerate(shape):
dtype = onp.int32 if size < onp.iinfo(onp.int32).max else onp.int64
iotas.append(broadcasted_iota(dtype, shape, dim))
primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension,
is_stable=is_stable, comparator=comparator)
is_stable=is_stable, num_keys=num_keys)
idx = tuple(primals[-1] if i == dimension else iotas[i]
for i in range(len(shape)))
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
return tuple(primals[:-1]), tangents_out
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, comparator):
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
prototype_arg, new_bdim = next(
(a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
new_args = []
@ -5142,7 +5142,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, comparat
new_args.append(batching.moveaxis(arg, bdim, new_bdim))
new_dimension = dimension + (new_bdim <= dimension)
bdims = (new_bdim,) * len(new_args)
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, comparator=comparator),
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, num_keys=num_keys),
bdims)