Make lax.sort support tuple arguments using a variadic sort. (#3085)

* Make lax.sort support tuple arguments using a variadic sort.

Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly.

Remove sort_key_val_p, since it is redundant with a variadic sort_p.

* Fix mypy errors.

* Change JVP rule to use NumPy indexing.
Remove redundant case in batching rule.
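As an illustration of the new API (a minimal usage sketch, not part of the commit; array values are made up):

import jax.numpy as jnp
from jax import lax

keys = jnp.array([3., 1., 2.])
vals = jnp.array([30., 10., 20.])

# A tuple argument invokes the variadic sort: all operands are sorted
# stably by the first one, and a tuple is returned.
sorted_keys, sorted_vals = lax.sort((keys, vals), dimension=0)
# sorted_keys: [1., 2., 3.]   sorted_vals: [10., 20., 30.]

# A single array argument still returns a single sorted array.
lax.sort(keys)  # [1., 2., 3.]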
Peter Hawkins, 2020-05-14 11:13:15 -04:00, committed via GitHub
commit 4ce2aa2563 (parent 16cf845148)
4 changed files with 96 additions and 103 deletions

@@ -248,7 +248,6 @@ from .lax import (
   slice_p,
   sort,
   sort_key_val,
-  sort_key_val_p,
   sort_p,
   sqrt,
   sqrt_p,

@@ -1176,20 +1176,27 @@ def cumprod(operand: Array, axis: int) -> Array:
   """Computes a cumulative product along `axis`."""
   return cumprod_p.bind(operand, axis=int(axis))

-def sort(operand: Array, dimension: int = -1) -> Array:
+def sort(operand: Union[Array, Tuple[Array, ...]], dimension: int = -1
+         ) -> Union[Array, Tuple[Array, ...]]:
   """Wraps XLA's `Sort
   <https://www.tensorflow.org/xla/operation_semantics#sort>`_
   operator.
   """
-  return sort_p.bind(operand, dimension=dimension)
+  if isinstance(operand, tuple):
+    if len(operand) == 0:
+      raise TypeError("Sort requires at least one operand")
+    dimension = _canonicalize_axis(dimension, len(operand[0].shape))
+    return tuple(sort_p.bind(*operand, dimension=dimension))
+  else:
+    dimension = _canonicalize_axis(dimension, len(operand.shape))
+    return sort_p.bind(operand, dimension=dimension)[0]

 def sort_key_val(keys: Array, values: Array,
                  dimension: int = -1) -> Tuple[Array, Array]:
   """Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``."""
-  # TODO(mattjj): new sort_key_val is variadic
-  result = sort_key_val_p.bind(keys, values, dimension=dimension)
-  sorted_keys, sorted_values = result
-  return sorted_keys, sorted_values
+  dimension = _canonicalize_axis(dimension, len(keys.shape))
+  k, v = sort_p.bind(keys, values, dimension=dimension)
+  return k, v

 def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
   """Returns top ``k`` values and their indices along the last axis of ``operand``."""
@@ -4548,100 +4555,54 @@ xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
     multiple_results=False)
 batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)

-sort_shape = lambda operand, dimension: operand.shape
-
-def _sort_jvp_rule(g, operand, *, dimension):
-  _, g_out = sort_key_val(operand, g, dimension)
-  return g_out
+def _sort_abstract_eval(*args, **kwargs):
+  args = tuple(raise_to_shaped(arg) for arg in args)
+  if any(arg.shape != args[0].shape for arg in args[1:]):
+    shapes = " ".join(str(a.shape) for a in args)
+    raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
+  return args
+
+def _sort_translation_rule(c, *operands, dimension):
+  out = xops.Sort(c, operands, dimension=dimension, is_stable=True)
+  return out if len(operands) != 1 else xops.Tuple(c, [out])
+
+def _sort_jvp(primals, tangents, *, dimension):
+  shape = primals[0].shape
+  iotas = []
+  for dim, size in enumerate(shape):
+    dtype = onp.int32 if size < onp.iinfo(onp.int32).max else onp.int64
+    iotas.append(broadcasted_iota(dtype, shape, dim))
+  primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension)
+  idx = tuple(primals[-1] if i == dimension else iotas[i]
+              for i in range(len(shape)))
+  tangents_out = tuple(ad_util.zero if t is ad_util.zero else t[idx]
+                       for t in tangents)
+  return tuple(primals[:-1]), tangents_out

 def _sort_batch_rule(batched_args, batch_dims, *, dimension):
-  operand, = batched_args
-  bdim, = batch_dims
-  dimension = dimension % (operand.ndim - 1)
-  new_dimension = dimension + (bdim <= dimension)
-  return sort(operand, dimension=new_dimension), bdim
+  prototype_arg, new_bdim = next(
+    (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
+  new_args = []
+  for arg, bdim in zip(batched_args, batch_dims):
+    if bdim is None:
+      dims = onp.delete(onp.arange(prototype_arg.ndim), new_bdim)
+      new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims))
+    else:
+      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
+  new_dimension = dimension + (new_bdim <= dimension)
+  bdims = (new_bdim,) * len(new_args)
+  return sort_p.bind(*new_args, dimension=new_dimension), bdims

-def _sort_translation_rule(c, operand, *, dimension):
-  return xops.Sort(c, [operand], dimension=dimension, is_stable=True)
-
-sort_p = standard_primitive(sort_shape, _input_dtype, 'sort',
-                            translation_rule=_sort_translation_rule)
-ad.defjvp(sort_p, _sort_jvp_rule)
+sort_p = Primitive('sort')
+sort_p.multiple_results = True
+sort_p.def_impl(partial(xla.apply_primitive, sort_p))
+sort_p.def_abstract_eval(_sort_abstract_eval)
+xla.translations[sort_p] = _sort_translation_rule
+ad.primitive_jvps[sort_p] = _sort_jvp
 batching.primitive_batchers[sort_p] = _sort_batch_rule

-def _sort_key_val_abstract_eval(keys, values, *, dimension):
-  return raise_to_shaped(keys), raise_to_shaped(values)
-
-def _sort_key_val_jvp(primals, tangents, *, dimension):
-  # NOTE(mattjj): this re-sorts three times, but if we had a variadic
-  # sort_key_val, or if we could apply a fixed permutation efficiently, we could
-  # implement this jvp rule with a single sort. The apply_permutation primitive
-  # would make the jvp (and corresponding transpose rule) faster and easier.
-  # This would also be cleaner if we didn't get the sorted keys out.
-  # TODO(mattjj): make sort_key_val variadic, no sorted keys out by default
-  keys, values = primals
-  keys_tangents, values_tangents = tangents
-
-  val_out = sort_key_val(keys, values, dimension)
-
-  if keys_tangents is ad_util.zero:
-    keys_tangents_out = ad_util.zero
-  else:
-    keys_tangents_out = _sort_jvp_rule(keys_tangents, keys, dimension=dimension)
-
-  if values_tangents is ad_util.zero:
-    values_tangents_out = ad_util.zero
-  else:
-    values_tangents_out = _sort_jvp_rule(values_tangents, keys,
-                                         dimension=dimension)
-
-  tangents_out = keys_tangents_out, values_tangents_out
-  return val_out, tangents_out
-
-def _sort_key_val_transpose_rule(t, keys, values, *, dimension):
-  t_keys, t_values = t
-  assert t_keys is ad_util.zero
-  iota = broadcasted_iota(onp.int32, keys.shape, dimension % keys.ndim)
-  _, perm = sort_key_val(keys, iota)
-  keys_result = ad_util.zero if ad.is_undefined_primal(keys) else None
-  values_result = sort_key_val(perm, t_values)[1] if ad.is_undefined_primal(values) else None
-  return [keys_result, values_result]
-
-def _sort_key_val_batch_rule(batched_args, batch_dims, *, dimension):
-  keys, values = batched_args
-  keys_bdim, values_bdim = batch_dims
-  assert keys_bdim is not None or values_bdim is not None
-  if keys_bdim == values_bdim:
-    new_dimension = dimension + (keys_bdim <= dimension)
-    return sort_key_val(keys, values, new_dimension), (keys_bdim, keys_bdim)
-  elif keys_bdim is not None and values_bdim is not None:
-    keys_trans = batching.moveaxis(keys, keys_bdim, values_bdim)
-    new_dimension = dimension + (values_bdim <= dimension)
-    return sort_key_val(keys_trans, values, new_dimension), (values_bdim, values_bdim)
-  elif keys_bdim is None:
-    broadcast_dimensions = onp.delete(onp.arange(values.ndim), values_bdim)
-    new_keys = broadcast_in_dim(keys, values.shape, broadcast_dimensions)
-    new_dimension = dimension + (values_bdim <= dimension)
-    return sort_key_val(new_keys, values, new_dimension), (values_bdim, values_bdim)
-  elif values_bdim is None:
-    broadcast_dimensions = onp.delete(onp.arange(keys.ndim), keys_bdim)
-    new_values = broadcast_in_dim(values, keys.shape, broadcast_dimensions)
-    new_dimension = dimension + (keys_bdim <= dimension)
-    return sort_key_val(keys, new_values, new_dimension), (keys_bdim, keys_bdim)
-  else:
-    assert False  # unreachable
-
-def _sort_key_val_translation_rule(c, keys, values, *, dimension):
-  return xops.Sort(c, [keys, values], dimension=dimension, is_stable=True)
-
-sort_key_val_p = Primitive('sort_key_val')
-sort_key_val_p.multiple_results = True
-sort_key_val_p.def_impl(partial(xla.apply_primitive, sort_key_val_p))
-sort_key_val_p.def_abstract_eval(_sort_key_val_abstract_eval)
-xla.translations[sort_key_val_p] = _sort_key_val_translation_rule
-ad.primitive_jvps[sort_key_val_p] = _sort_key_val_jvp
-ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
-batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule

 def _top_k_abstract_eval(operand, *, k):
   if k < 0:
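The new _sort_jvp rule above sorts only once: it appends an iota along the sort dimension as an extra operand, so the sorted iota is exactly the permutation, and each tangent is then gathered with plain NumPy-style indexing. A standalone NumPy sketch of the same idea (illustrative, not JAX internals):

import numpy as np

x = np.array([[3., 1., 2.]])      # primal; sort along dimension 1
t = np.array([[0.3, 0.1, 0.2]])   # tangent of x

# A stable argsort plays the role of the sorted iota operand: it is the
# permutation that the variadic sort applies to every operand.
perm = np.argsort(x, axis=1, kind='stable')
rows = np.arange(x.shape[0])[:, None]   # iota along the batch dimension

x_sorted = x[rows, perm]   # [[1., 2., 3.]]   -- the primal output
t_sorted = t[rows, perm]   # [[0.1, 0.2, 0.3]] -- tangent gathered, not re-sorted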
@@ -5205,7 +5166,6 @@ def _abstractify(x):
   return raise_to_shaped(core.get_aval(x))
-

 def _check_user_dtype_supported(dtype, fun_name=None):
   onp_dtype = onp.dtype(dtype)
   if onp_dtype.kind not in "biufc" and onp_dtype.type != dtypes.bfloat16:
@@ -5220,3 +5180,15 @@ def _check_user_dtype_supported(dtype, fun_name=None):
     fun_name = "requested in {}".format(fun_name) if fun_name else ""
     truncated_dtype = dtypes.canonicalize_dtype(dtype).name
     warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
+
+
+def _canonicalize_axis(axis, num_dims):
+  """Canonicalize an axis in (-num_dims, num_dims) to [0, num_dims)."""
+  axis = int(axis)
+  if axis < 0:
+    axis = axis + num_dims
+  if axis < 0 or axis >= num_dims:
+    raise ValueError(
+        "axis {} is out of bounds for array of dimension {}".format(
+            axis, num_dims))
+  return axis
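For reference, the _canonicalize_axis helper added above behaves as follows (illustrative calls):

_canonicalize_axis(-1, 3)  # -> 2
_canonicalize_axis(0, 3)   # -> 0
_canonicalize_axis(3, 3)   # -> ValueError: axis 3 is out of bounds for array of dimension 3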

@@ -2952,9 +2952,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
     raise ValueError("'order' argument to sort is not supported.")

   if axis is None:
-    return lax.sort(a.ravel(), 0)
+    return lax.sort(a.ravel(), dimension=0)
   else:
-    return lax.sort(a, _canonicalize_axis(axis, ndim(a)))
+    return lax.sort(a, dimension=_canonicalize_axis(axis, ndim(a)))


 @_wraps(np.argsort)
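The jax.numpy wrapper now passes `dimension` by keyword; its user-facing semantics are unchanged (a small usage sketch):

import jax.numpy as jnp

a = jnp.array([[3, 1], [2, 0]])
jnp.sort(a, axis=-1)    # [[1, 3], [0, 2]]
jnp.sort(a, axis=None)  # flattens first: [0, 1, 2, 3]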

@@ -1310,7 +1310,7 @@ class LaxTest(jtu.JaxTestCase):
   def testSort(self, shape, dtype, axis, rng_factory):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    fun = lambda x: lax.sort(x, axis)
+    fun = lambda x: lax.sort(x, dimension=axis)
     self._CompileAndCheck(fun, args_maker, check_dtypes=True)

   @parameterized.named_parameters(jtu.cases_from_list(
@@ -1324,7 +1324,7 @@ class LaxTest(jtu.JaxTestCase):
   def testSortAgainstNumpy(self, shape, dtype, axis, rng_factory):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    op = lambda x: lax.sort(x, axis)
+    op = lambda x: lax.sort(x, dimension=axis)
     numpy_op = lambda x: lax_reference.sort(x, axis)
     self._CheckAgainstNumpy(op, numpy_op, args_maker)
@@ -2442,7 +2442,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
   def testSortGrad(self, shape, dtype, axis, rng_factory):
     rng = rng_factory(self.rng())
     operand = rng(shape, dtype)
-    sort = lambda x: lax.sort(x, axis)
+    sort = lambda x: lax.sort(x, dimension=axis)
     check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

   # TODO(b/205052657): enable more tests when supported
@@ -2678,7 +2678,8 @@ class LaxVmapTest(jtu.JaxTestCase):
   def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                      rtol=None, atol=None):
-    batched_shapes = list(map(partial(add_bdim, bdim_size), bdims, shapes))
+    batched_shapes = list(jax.util.safe_map(partial(add_bdim, bdim_size),
+                                            bdims, shapes))
     args = [rng(shape, dtype)
             for shape, dtype in jax.util.safe_zip(batched_shapes, dtypes)]
     args_slice = args_slicer(args, bdims)
@@ -3241,12 +3242,33 @@ class LaxVmapTest(jtu.JaxTestCase):
     op2 = lambda x: lax.top_k(x, k=k)[1]
     self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}"
+       .format(jtu.format_shape_dtype_string(shape, onp.float32), dimension,
+               arity, bdims),
+       "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims}
+      for shape in [(2, 3)]
+      for dimension in [0, 1]
+      for arity in range(3)
+      for bdims in all_bdims(*((shape,) * arity))))
+  def testSort(self, shape, dimension, arity, bdims):
+    rng = jtu.rand_default(self.rng())
+    if arity == 1:
+      fun = partial(lax.sort, dimension=dimension)
+      self._CheckBatching(fun, 5, bdims, (shape,) * arity, (onp.float32,) * arity,
+                          rng)
+    else:
+      for i in range(arity):
+        fun = lambda *args, i=i: lax.sort(args, dimension=dimension)[i]
+        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
+                            (onp.float32,) * arity, rng)
+
   # TODO Concatenate
   # TODO Reverse
   # TODO DynamicSlice
   # TODO DynamicUpdateSlice
   # TODO Sort
   # TODO SortKeyVal
   # TODO Collapse
   # TODO Scatter