From 4b1400dbb9d85f987691f654a272aa8408df6e6f Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Tue, 11 Feb 2025 00:08:08 -0800 Subject: [PATCH] #jax Optimize `jax.numpy.take_along_axis` along the dimension satisfies * the dimension is not the one along which to take values * the dimension size of input tensor is 1 * the dimension size of the indices is not 1 Previously, we create constant zero as the dummy indices, which is redundant. We can squeeze the input tensor and generate the `stablehlo.gather` directly. In the following example, ``` h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32) g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-2) ``` It lowers to the following module before this change, ``` module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) { %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32) return %0 : tensor<2x3x5x11x13xf32> loc(#loc) } loc(#loc) func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> { %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33) %c = stablehlo.constant dense<0> : tensor loc(#loc32) %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2x3x5x11x1xi32> loc(#loc34) %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2x3x5x11x1xi32> loc(#loc35) %3 = stablehlo.compare LT, %0, %2, SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35) %c_0 = stablehlo.constant dense<7> : tensor loc(#loc32) %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor<2x3x5x11x1xi32> loc(#loc36) %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36) %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37) %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38) %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39) %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33) %c_2 = stablehlo.constant dense<0> : tensor loc(#loc39) %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2x3x5x11x2xi64> loc(#loc40) %10 = stablehlo.compare GE, %8, %9, SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40) %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34) %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41) %13 = stablehlo.compare LE, %8, %12, SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41) %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42) %c_3 = stablehlo.constant dense : tensor loc(#loc43) %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor) -> tensor<2x3x5x11xi1> loc(#loc43) %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39) %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34) %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc39) %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x5x11x13xf32> loc(#loc34) %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37) return %19 : tensor<2x3x5x11x13xf32> loc(#loc32) } } ``` With this change, we have ``` module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) { %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32) return %0 : tensor<2x3x5x11x13xf32> loc(#loc) } loc(#loc) func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> { %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33) %c = stablehlo.constant dense<0> : tensor loc(#loc32) %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2x3x5x11x1xi32> loc(#loc34) %2 = stablehlo.compare LT, %0, %1, SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34) %c_0 = stablehlo.constant dense<7> : tensor loc(#loc32) %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor<2x3x5x11x1xi32> loc(#loc35) %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35) %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36) %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37) %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38) %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33) %c_2 = stablehlo.constant dense<0> : tensor loc(#loc38) %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2x3x5x11x1xi64> loc(#loc39) %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39) %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40) %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41) %12 = stablehlo.compare LE, %7, %11, SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41) %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42) %c_3 = stablehlo.constant dense : tensor loc(#loc43) %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor) -> tensor<2x3x5x11xi1> loc(#loc43) %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38) %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40) %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc38) %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x5x11x13xf32> loc(#loc40) %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36) return %18 : tensor<2x3x5x11x13xf32> loc(#loc32) } } ``` PiperOrigin-RevId: 725506779 --- jax/_src/numpy/lax_numpy.py | 27 +++++++++++++++++---------- tests/lax_numpy_test.py | 7 +++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0031341ab..cc56e9402 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10835,6 +10835,12 @@ def take_along_axis( collapsed_slice_dims = [] operand_batching_dims = [] start_indices_batching_dims = [] + + # We will squeeze the array. i is the index of the unsqueezed shape, while + # new_i is the index of the squeezed shape. j is the index of the gather + # indices. + dims_to_squeeze = [] + new_i = 0 j = 0 for i in range(rank): if i == axis_int: @@ -10842,22 +10848,20 @@ def take_along_axis( indices = _normalize_index(indices, axis_size) gather_indices.append(lax.reshape(indices, gather_index_shape)) slice_sizes.append(1) - start_index_map.append(i) - collapsed_slice_dims.append(i) + start_index_map.append(new_i) + collapsed_slice_dims.append(new_i) + new_i += 1 j += 1 elif core.definitely_equal(idx_shape[i], 1): # If idx_shape[i] == 1, we can just take the entirety of the arr's axis # and avoid forming an iota index. offset_dims.append(i) slice_sizes.append(arr_shape[i]) + new_i += 1 elif core.definitely_equal(arr_shape[i], 1): - # If the array dimension is 1 but the index dimension is not, we - # broadcast the array dimension to the index dimension by repeatedly - # gathering the first element. - gather_indices.append(zeros(gather_index_shape, dtype=index_dtype)) - slice_sizes.append(1) - start_index_map.append(i) - collapsed_slice_dims.append(i) + # If the array dimension is 1 but the index dimension is not, we will + # squeeze this dimension. + dims_to_squeeze.append(i) j += 1 else: # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both @@ -10866,10 +10870,13 @@ def take_along_axis( slice_sizes.append(0) else: slice_sizes.append(1) - operand_batching_dims.append(i) + operand_batching_dims.append(new_i) start_indices_batching_dims.append(j) + new_i += 1 j += 1 + # Squeeze a to remove singleton dimensions. + a = lax.squeeze(a, dims_to_squeeze) gather_indices_arr = lax.concatenate(gather_indices, dimension=j) dnums = lax.GatherDimensionNumbers( offset_dims=tuple(offset_dims), diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 1d6cfad8a..45698d3b1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4721,6 +4721,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): q1 = np.take_along_axis( h, g, axis=-1) np.testing.assert_equal(q0, q1) + def testTakeAlongAxisInputTensorHasSingletonDimension(self): + h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32) + g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8) + q0 = jnp.take_along_axis(h, g, axis=-2) + q1 = np.take_along_axis( h, g, axis=-2) + np.testing.assert_equal(q0, q1) + def testTakeAlongAxisOutOfBounds(self): x = jnp.arange(10, dtype=jnp.float32) idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])