mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
lexicographic sort_p: accept num_keys rather than comparator (#3715)
This commit is contained in:
parent
d2f9c46a0c
commit
60d852773e
@ -1242,18 +1242,18 @@ def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
|
||||
dimension = _canonicalize_axis(dimension, len(operand[0].shape))
|
||||
return tuple(sort_p.bind(*operand, dimension=dimension,
|
||||
is_stable=is_stable,
|
||||
comparator=partial(_sort_lt_comparator, num_keys=num_keys)))
|
||||
num_keys=num_keys))
|
||||
else:
|
||||
if num_keys != 1:
|
||||
raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
|
||||
dimension = _canonicalize_axis(dimension, len(operand.shape))
|
||||
return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, comparator=_sort_lt_comparator)[0]
|
||||
return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, num_keys=1)[0]
|
||||
|
||||
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
|
||||
is_stable: bool = True) -> Tuple[Array, Array]:
|
||||
"""Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``."""
|
||||
dimension = _canonicalize_axis(dimension, len(keys.shape))
|
||||
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, comparator=_sort_lt_comparator)
|
||||
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1)
|
||||
return k, v
|
||||
|
||||
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
|
||||
@ -5105,32 +5105,32 @@ def _sort_lt_comparator(*operands, num_keys=1):
|
||||
return p
|
||||
|
||||
|
||||
def _sort_translation_rule(c, *operands, dimension, is_stable, comparator):
|
||||
def _sort_translation_rule(c, *operands, dimension, is_stable, num_keys):
|
||||
types = [c.get_shape(x).xla_element_type() for x in operands]
|
||||
subc = xla_bridge.make_computation_builder("sort_lt_comparator")
|
||||
params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
|
||||
for i, typ in enumerate(types) for j in range(2)]
|
||||
result = xla.lower_fun(comparator,
|
||||
result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
|
||||
multiple_results=False)(subc, *params)
|
||||
comparator = subc.build(result)
|
||||
out = xops.Sort(c, operands, dimension=dimension, is_stable=is_stable,
|
||||
comparator=comparator)
|
||||
return out if len(operands) != 1 else xops.Tuple(c, [out])
|
||||
|
||||
def _sort_jvp(primals, tangents, *, dimension, is_stable, comparator):
|
||||
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
|
||||
shape = primals[0].shape
|
||||
iotas = []
|
||||
for dim, size in enumerate(shape):
|
||||
dtype = onp.int32 if size < onp.iinfo(onp.int32).max else onp.int64
|
||||
iotas.append(broadcasted_iota(dtype, shape, dim))
|
||||
primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension,
|
||||
is_stable=is_stable, comparator=comparator)
|
||||
is_stable=is_stable, num_keys=num_keys)
|
||||
idx = tuple(primals[-1] if i == dimension else iotas[i]
|
||||
for i in range(len(shape)))
|
||||
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
|
||||
return tuple(primals[:-1]), tangents_out
|
||||
|
||||
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, comparator):
|
||||
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
|
||||
prototype_arg, new_bdim = next(
|
||||
(a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
|
||||
new_args = []
|
||||
@ -5142,7 +5142,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, comparat
|
||||
new_args.append(batching.moveaxis(arg, bdim, new_bdim))
|
||||
new_dimension = dimension + (new_bdim <= dimension)
|
||||
bdims = (new_bdim,) * len(new_args)
|
||||
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, comparator=comparator),
|
||||
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, num_keys=num_keys),
|
||||
bdims)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user