mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds". PiperOrigin-RevId: 722930215
This commit is contained in:
parent
654a2f6e61
commit
124e123946
@ -11649,7 +11649,8 @@ def take_along_axis(
|
||||
j = 0
|
||||
for i in range(rank):
|
||||
if i == axis_int:
|
||||
indices = _normalize_index(indices, axis_size)
|
||||
if mode != 'promise_in_bounds':
|
||||
indices = _normalize_index(indices, axis_size)
|
||||
gather_indices.append(lax.reshape(indices, gather_index_shape))
|
||||
slice_sizes.append(1)
|
||||
start_index_map.append(i)
|
||||
|
@ -2129,7 +2129,11 @@ def _gather_lowering_rule(
|
||||
slice_sizes == (1, 1)
|
||||
and not unique_indices
|
||||
and not indices_are_sorted
|
||||
and mode == lax.GatherScatterMode.FILL_OR_DROP
|
||||
and mode
|
||||
in (
|
||||
lax.GatherScatterMode.FILL_OR_DROP,
|
||||
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
|
||||
)
|
||||
):
|
||||
if dimension_numbers == lax.GatherDimensionNumbers(
|
||||
offset_dims=(),
|
||||
|
@ -389,10 +389,10 @@ class OpsTest(PallasBaseTest):
|
||||
ref = jax.jit(lambda x: round_fn(x).astype(target))(x)
|
||||
np.testing.assert_array_equal(out, ref)
|
||||
|
||||
@parameterized.product(axis=[0, 1])
|
||||
def test_dynamic_gather_along_axis(self, axis):
|
||||
if not jtu.if_cloud_tpu_at_least(2025, 2, 3):
|
||||
self.skipTest("Requires libtpu built after 2025-02-03")
|
||||
@parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None])
|
||||
def test_dynamic_gather_along_axis(self, axis, mode):
|
||||
if not jtu.if_cloud_tpu_at_least(2025, 2, 5):
|
||||
self.skipTest("Requires libtpu built after 2025-02-05")
|
||||
if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or (
|
||||
axis == 1 and not jtu.is_device_tpu_at_least(version=4)
|
||||
):
|
||||
@ -401,7 +401,7 @@ class OpsTest(PallasBaseTest):
|
||||
shape = (8, 128)
|
||||
|
||||
def kernel(x, indices, out):
|
||||
out[...] = jnp.take_along_axis(x[...], indices[...], axis)
|
||||
out[...] = jnp.take_along_axis(x[...], indices[...], axis, mode=mode)
|
||||
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
idx = jax.random.randint(
|
||||
|
Loading…
x
Reference in New Issue
Block a user