#jax Optimize jax.numpy.take_along_axis along the dimension satisfies

* the dimension is not the one along which to take values
* the dimension size of input tensor is 1
* the dimension size of the indices is not 1

Previously, we create constant zero as the dummy indices, which is redundant. We can squeeze the input tensor and generate the `stablehlo.gather` directly.

In the following example,
```
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
```
It lowers to the following module before this change,
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %3 = stablehlo.compare  LT, %0, %2,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36)
    %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37)
    %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38)
    %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39)
    %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39)
    %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40)
    %10 = stablehlo.compare  GE, %8, %9,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34)
    %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41)
    %13 = stablehlo.compare  LE, %8, %12,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41)
    %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39)
    %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39)
    %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34)
    %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37)
    return %19 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

With this change, we have
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.compare  LT, %0, %1,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35)
    %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37)
    %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38)
    %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38)
    %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39)
    %9 = stablehlo.compare  GE, %7, %8,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39)
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41)
    %12 = stablehlo.compare  LE, %7, %11,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41)
    %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38)
    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38)
    %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40)
    %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36)
    return %18 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

PiperOrigin-RevId: 725506779
This commit is contained in:
Zixuan Jiang 2025-02-11 00:08:08 -08:00 committed by jax authors
parent 1e447c8ad2
commit 4b1400dbb9
2 changed files with 24 additions and 10 deletions

View File

@ -10835,6 +10835,12 @@ def take_along_axis(
collapsed_slice_dims = []
operand_batching_dims = []
start_indices_batching_dims = []
# We will squeeze the array. i is the index of the unsqueezed shape, while
# new_i is the index of the squeezed shape. j is the index of the gather
# indices.
dims_to_squeeze = []
new_i = 0
j = 0
for i in range(rank):
if i == axis_int:
@ -10842,22 +10848,20 @@ def take_along_axis(
indices = _normalize_index(indices, axis_size)
gather_indices.append(lax.reshape(indices, gather_index_shape))
slice_sizes.append(1)
start_index_map.append(i)
collapsed_slice_dims.append(i)
start_index_map.append(new_i)
collapsed_slice_dims.append(new_i)
new_i += 1
j += 1
elif core.definitely_equal(idx_shape[i], 1):
# If idx_shape[i] == 1, we can just take the entirety of the arr's axis
# and avoid forming an iota index.
offset_dims.append(i)
slice_sizes.append(arr_shape[i])
new_i += 1
elif core.definitely_equal(arr_shape[i], 1):
# If the array dimension is 1 but the index dimension is not, we
# broadcast the array dimension to the index dimension by repeatedly
# gathering the first element.
gather_indices.append(zeros(gather_index_shape, dtype=index_dtype))
slice_sizes.append(1)
start_index_map.append(i)
collapsed_slice_dims.append(i)
# If the array dimension is 1 but the index dimension is not, we will
# squeeze this dimension.
dims_to_squeeze.append(i)
j += 1
else:
# Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both
@ -10866,10 +10870,13 @@ def take_along_axis(
slice_sizes.append(0)
else:
slice_sizes.append(1)
operand_batching_dims.append(i)
operand_batching_dims.append(new_i)
start_indices_batching_dims.append(j)
new_i += 1
j += 1
# Squeeze a to remove singleton dimensions.
a = lax.squeeze(a, dims_to_squeeze)
gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(offset_dims),

View File

@ -4721,6 +4721,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
q1 = np.take_along_axis( h, g, axis=-1)
np.testing.assert_equal(q0, q1)
def testTakeAlongAxisInputTensorHasSingletonDimension(self):
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
q1 = np.take_along_axis( h, g, axis=-2)
np.testing.assert_equal(q0, q1)
def testTakeAlongAxisOutOfBounds(self):
x = jnp.arange(10, dtype=jnp.float32)
idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])