[Pallas] Support promise_in_bounds mode in jnp.take_along_axis.

Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
This commit is contained in:
Jevin Jiang 2025-02-03 22:05:08 -08:00 committed by jax authors
parent 654a2f6e61
commit 124e123946
3 changed files with 12 additions and 7 deletions

View File

@ -11649,7 +11649,8 @@ def take_along_axis(
j = 0
for i in range(rank):
if i == axis_int:
indices = _normalize_index(indices, axis_size)
if mode != 'promise_in_bounds':
indices = _normalize_index(indices, axis_size)
gather_indices.append(lax.reshape(indices, gather_index_shape))
slice_sizes.append(1)
start_index_map.append(i)

View File

@ -2129,7 +2129,11 @@ def _gather_lowering_rule(
slice_sizes == (1, 1)
and not unique_indices
and not indices_are_sorted
and mode == lax.GatherScatterMode.FILL_OR_DROP
and mode
in (
lax.GatherScatterMode.FILL_OR_DROP,
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
)
):
if dimension_numbers == lax.GatherDimensionNumbers(
offset_dims=(),

View File

@ -389,10 +389,10 @@ class OpsTest(PallasBaseTest):
ref = jax.jit(lambda x: round_fn(x).astype(target))(x)
np.testing.assert_array_equal(out, ref)
@parameterized.product(axis=[0, 1])
def test_dynamic_gather_along_axis(self, axis):
if not jtu.if_cloud_tpu_at_least(2025, 2, 3):
self.skipTest("Requires libtpu built after 2025-02-03")
@parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None])
def test_dynamic_gather_along_axis(self, axis, mode):
if not jtu.if_cloud_tpu_at_least(2025, 2, 5):
self.skipTest("Requires libtpu built after 2025-02-05")
if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or (
axis == 1 and not jtu.is_device_tpu_at_least(version=4)
):
@ -401,7 +401,7 @@ class OpsTest(PallasBaseTest):
shape = (8, 128)
def kernel(x, indices, out):
out[...] = jnp.take_along_axis(x[...], indices[...], axis)
out[...] = jnp.take_along_axis(x[...], indices[...], axis, mode=mode)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
idx = jax.random.randint(