Merge pull request #25459 from hawkinsp:sort

PiperOrigin-RevId: 705869484
This commit is contained in:
jax authors 2024-12-13 06:55:32 -08:00
commit 5a3fa500b5
2 changed files with 21 additions and 10 deletions

View File

@ -5506,15 +5506,26 @@ def _operands_to_keys(*operands, num_keys=1):
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
shape = primals[0].shape
iotas = []
for dim, size in enumerate(shape):
iotas.append(broadcasted_iota(np.int64, shape, dim))
sorted_primals_and_idx = sort_p.bind(
*primals, iotas[dimension], dimension=dimension,
is_stable=is_stable, num_keys=num_keys)
idx = tuple(sorted_primals_and_idx[-1] if i == dimension else iotas[i]
for i in range(len(shape)))
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
*primals, broadcasted_iota(np.uint64, shape, dimension),
dimension=dimension, is_stable=is_stable, num_keys=num_keys)
batch_dims = tuple(np.delete(np.arange(len(shape), dtype=np.int64),
dimension))
dnums = slicing.GatherDimensionNumbers(
offset_dims=(),
collapsed_slice_dims=(dimension,),
start_index_map=(dimension,),
operand_batching_dims=batch_dims,
start_indices_batching_dims=batch_dims,
)
idx = expand_dims(sorted_primals_and_idx[-1], (len(shape),))
gather_idx = partial(
slicing.gather,
start_indices=idx, dimension_numbers=dnums, slice_sizes=(1,) * len(shape),
mode=slicing.GatherScatterMode.PROMISE_IN_BOUNDS
)
tangents_out = [t if type(t) is ad_util.Zero else gather_idx(t)
for t in tangents]
return tuple(sorted_primals_and_idx[:-1]), tangents_out
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):

View File

@ -834,7 +834,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# TODO(b/205052657): enable more tests when supported
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(5,), (5, 7)]
for shape in [(5,), (5, 7), (4, 9, 3)]
for axis in [len(shape) - 1]
],
dtype=[np.float32],
@ -849,7 +849,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# TODO(b/205052657): enable more tests when supported
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(3,), (5, 3)]
for shape in [(3,), (5, 3), (4, 9, 3)]
for axis in [len(shape) - 1]
],
key_dtype=[np.float32],