mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #25459 from hawkinsp:sort
PiperOrigin-RevId: 705869484
This commit is contained in:
commit
5a3fa500b5
@ -5506,15 +5506,26 @@ def _operands_to_keys(*operands, num_keys=1):
|
||||
|
||||
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
|
||||
shape = primals[0].shape
|
||||
iotas = []
|
||||
for dim, size in enumerate(shape):
|
||||
iotas.append(broadcasted_iota(np.int64, shape, dim))
|
||||
sorted_primals_and_idx = sort_p.bind(
|
||||
*primals, iotas[dimension], dimension=dimension,
|
||||
is_stable=is_stable, num_keys=num_keys)
|
||||
idx = tuple(sorted_primals_and_idx[-1] if i == dimension else iotas[i]
|
||||
for i in range(len(shape)))
|
||||
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
|
||||
*primals, broadcasted_iota(np.uint64, shape, dimension),
|
||||
dimension=dimension, is_stable=is_stable, num_keys=num_keys)
|
||||
batch_dims = tuple(np.delete(np.arange(len(shape), dtype=np.int64),
|
||||
dimension))
|
||||
dnums = slicing.GatherDimensionNumbers(
|
||||
offset_dims=(),
|
||||
collapsed_slice_dims=(dimension,),
|
||||
start_index_map=(dimension,),
|
||||
operand_batching_dims=batch_dims,
|
||||
start_indices_batching_dims=batch_dims,
|
||||
)
|
||||
idx = expand_dims(sorted_primals_and_idx[-1], (len(shape),))
|
||||
gather_idx = partial(
|
||||
slicing.gather,
|
||||
start_indices=idx, dimension_numbers=dnums, slice_sizes=(1,) * len(shape),
|
||||
mode=slicing.GatherScatterMode.PROMISE_IN_BOUNDS
|
||||
)
|
||||
tangents_out = [t if type(t) is ad_util.Zero else gather_idx(t)
|
||||
for t in tangents]
|
||||
return tuple(sorted_primals_and_idx[:-1]), tangents_out
|
||||
|
||||
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
|
||||
|
@ -834,7 +834,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
# TODO(b/205052657): enable more tests when supported
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, axis=axis)
|
||||
for shape in [(5,), (5, 7)]
|
||||
for shape in [(5,), (5, 7), (4, 9, 3)]
|
||||
for axis in [len(shape) - 1]
|
||||
],
|
||||
dtype=[np.float32],
|
||||
@ -849,7 +849,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
# TODO(b/205052657): enable more tests when supported
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, axis=axis)
|
||||
for shape in [(3,), (5, 3)]
|
||||
for shape in [(3,), (5, 3), (4, 9, 3)]
|
||||
for axis in [len(shape) - 1]
|
||||
],
|
||||
key_dtype=[np.float32],
|
||||
|
Loading…
x
Reference in New Issue
Block a user