mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
#jax Optimize jax.numpy.take_along_axis
along the dimension satisfies
* the dimension is not the one along which to take values * the dimension size of input tensor is 1 * the dimension size of the indices is not 1 Previously, we create constant zero as the dummy indices, which is redundant. We can squeeze the input tensor and generate the `stablehlo.gather` directly. In the following example, ``` h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32) g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-2) ``` It lowers to the following module before this change, ``` module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) { %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32) return %0 : tensor<2x3x5x11x13xf32> loc(#loc) } loc(#loc) func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> { %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33) %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32) %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34) %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35) %3 = stablehlo.compare LT, %0, %2, SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35) %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32) %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36) %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36) %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37) %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38) %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39) %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33) %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39) %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40) %10 = stablehlo.compare GE, %8, %9, SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40) %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34) %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41) %13 = stablehlo.compare LE, %8, %12, SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41) %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42) %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43) %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43) %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39) %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34) %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39) %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34) %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37) return %19 : tensor<2x3x5x11x13xf32> loc(#loc32) } } ``` With this change, we have ``` module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) { %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32) return %0 : tensor<2x3x5x11x13xf32> loc(#loc) } loc(#loc) func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> { %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33) %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32) %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34) %2 = stablehlo.compare LT, %0, %1, SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34) %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32) %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35) %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35) %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36) %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37) %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38) %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33) %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38) %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39) %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39) %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40) %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41) %12 = stablehlo.compare LE, %7, %11, SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41) %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42) %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43) %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43) %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38) %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40) %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38) %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40) %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36) return %18 : tensor<2x3x5x11x13xf32> loc(#loc32) } } ``` PiperOrigin-RevId: 725506779
This commit is contained in:
parent
1e447c8ad2
commit
4b1400dbb9
@ -10835,6 +10835,12 @@ def take_along_axis(
|
||||
collapsed_slice_dims = []
|
||||
operand_batching_dims = []
|
||||
start_indices_batching_dims = []
|
||||
|
||||
# We will squeeze the array. i is the index of the unsqueezed shape, while
|
||||
# new_i is the index of the squeezed shape. j is the index of the gather
|
||||
# indices.
|
||||
dims_to_squeeze = []
|
||||
new_i = 0
|
||||
j = 0
|
||||
for i in range(rank):
|
||||
if i == axis_int:
|
||||
@ -10842,22 +10848,20 @@ def take_along_axis(
|
||||
indices = _normalize_index(indices, axis_size)
|
||||
gather_indices.append(lax.reshape(indices, gather_index_shape))
|
||||
slice_sizes.append(1)
|
||||
start_index_map.append(i)
|
||||
collapsed_slice_dims.append(i)
|
||||
start_index_map.append(new_i)
|
||||
collapsed_slice_dims.append(new_i)
|
||||
new_i += 1
|
||||
j += 1
|
||||
elif core.definitely_equal(idx_shape[i], 1):
|
||||
# If idx_shape[i] == 1, we can just take the entirety of the arr's axis
|
||||
# and avoid forming an iota index.
|
||||
offset_dims.append(i)
|
||||
slice_sizes.append(arr_shape[i])
|
||||
new_i += 1
|
||||
elif core.definitely_equal(arr_shape[i], 1):
|
||||
# If the array dimension is 1 but the index dimension is not, we
|
||||
# broadcast the array dimension to the index dimension by repeatedly
|
||||
# gathering the first element.
|
||||
gather_indices.append(zeros(gather_index_shape, dtype=index_dtype))
|
||||
slice_sizes.append(1)
|
||||
start_index_map.append(i)
|
||||
collapsed_slice_dims.append(i)
|
||||
# If the array dimension is 1 but the index dimension is not, we will
|
||||
# squeeze this dimension.
|
||||
dims_to_squeeze.append(i)
|
||||
j += 1
|
||||
else:
|
||||
# Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both
|
||||
@ -10866,10 +10870,13 @@ def take_along_axis(
|
||||
slice_sizes.append(0)
|
||||
else:
|
||||
slice_sizes.append(1)
|
||||
operand_batching_dims.append(i)
|
||||
operand_batching_dims.append(new_i)
|
||||
start_indices_batching_dims.append(j)
|
||||
new_i += 1
|
||||
j += 1
|
||||
|
||||
# Squeeze a to remove singleton dimensions.
|
||||
a = lax.squeeze(a, dims_to_squeeze)
|
||||
gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
|
||||
dnums = lax.GatherDimensionNumbers(
|
||||
offset_dims=tuple(offset_dims),
|
||||
|
@ -4721,6 +4721,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
q1 = np.take_along_axis( h, g, axis=-1)
|
||||
np.testing.assert_equal(q0, q1)
|
||||
|
||||
def testTakeAlongAxisInputTensorHasSingletonDimension(self):
|
||||
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
|
||||
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
|
||||
q0 = jnp.take_along_axis(h, g, axis=-2)
|
||||
q1 = np.take_along_axis( h, g, axis=-2)
|
||||
np.testing.assert_equal(q0, q1)
|
||||
|
||||
def testTakeAlongAxisOutOfBounds(self):
|
||||
x = jnp.arange(10, dtype=jnp.float32)
|
||||
idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])
|
||||
|
Loading…
x
Reference in New Issue
Block a user