From 124e123946d09661b6b28dcd82a618be93cb132c Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 3 Feb 2025 22:05:08 -0800 Subject: [PATCH] [Pallas] Support promise_in_bounds mode in jnp.take_along_axis. Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds". PiperOrigin-RevId: 722930215 --- jax/_src/numpy/lax_numpy.py | 3 ++- jax/_src/pallas/mosaic/lowering.py | 6 +++++- tests/pallas/tpu_ops_test.py | 10 +++++----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8403b8876..e00b5b3e8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -11649,7 +11649,8 @@ def take_along_axis( j = 0 for i in range(rank): if i == axis_int: - indices = _normalize_index(indices, axis_size) + if mode != 'promise_in_bounds': + indices = _normalize_index(indices, axis_size) gather_indices.append(lax.reshape(indices, gather_index_shape)) slice_sizes.append(1) start_index_map.append(i) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 92185cc2c..908b47604 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2129,7 +2129,11 @@ def _gather_lowering_rule( slice_sizes == (1, 1) and not unique_indices and not indices_are_sorted - and mode == lax.GatherScatterMode.FILL_OR_DROP + and mode + in ( + lax.GatherScatterMode.FILL_OR_DROP, + lax.GatherScatterMode.PROMISE_IN_BOUNDS, + ) ): if dimension_numbers == lax.GatherDimensionNumbers( offset_dims=(), diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index df2d5cb67..73170aa7a 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -389,10 +389,10 @@ class OpsTest(PallasBaseTest): ref = jax.jit(lambda x: round_fn(x).astype(target))(x) np.testing.assert_array_equal(out, ref) - @parameterized.product(axis=[0, 1]) - def test_dynamic_gather_along_axis(self, axis): - if not jtu.if_cloud_tpu_at_least(2025, 2, 3): - self.skipTest("Requires libtpu built after 2025-02-03") + @parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None]) + def test_dynamic_gather_along_axis(self, axis, mode): + if not jtu.if_cloud_tpu_at_least(2025, 2, 5): + self.skipTest("Requires libtpu built after 2025-02-05") if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or ( axis == 1 and not jtu.is_device_tpu_at_least(version=4) ): @@ -401,7 +401,7 @@ class OpsTest(PallasBaseTest): shape = (8, 128) def kernel(x, indices, out): - out[...] = jnp.take_along_axis(x[...], indices[...], axis) + out[...] = jnp.take_along_axis(x[...], indices[...], axis, mode=mode) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) idx = jax.random.randint(