diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0031341ab..cc56e9402 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10835,6 +10835,12 @@ def take_along_axis( collapsed_slice_dims = [] operand_batching_dims = [] start_indices_batching_dims = [] + + # We will squeeze the array. i is the index of the unsqueezed shape, while + # new_i is the index of the squeezed shape. j is the index of the gather + # indices. + dims_to_squeeze = [] + new_i = 0 j = 0 for i in range(rank): if i == axis_int: @@ -10842,22 +10848,20 @@ def take_along_axis( indices = _normalize_index(indices, axis_size) gather_indices.append(lax.reshape(indices, gather_index_shape)) slice_sizes.append(1) - start_index_map.append(i) - collapsed_slice_dims.append(i) + start_index_map.append(new_i) + collapsed_slice_dims.append(new_i) + new_i += 1 j += 1 elif core.definitely_equal(idx_shape[i], 1): # If idx_shape[i] == 1, we can just take the entirety of the arr's axis # and avoid forming an iota index. offset_dims.append(i) slice_sizes.append(arr_shape[i]) + new_i += 1 elif core.definitely_equal(arr_shape[i], 1): - # If the array dimension is 1 but the index dimension is not, we - # broadcast the array dimension to the index dimension by repeatedly - # gathering the first element. - gather_indices.append(zeros(gather_index_shape, dtype=index_dtype)) - slice_sizes.append(1) - start_index_map.append(i) - collapsed_slice_dims.append(i) + # If the array dimension is 1 but the index dimension is not, we will + # squeeze this dimension. + dims_to_squeeze.append(i) j += 1 else: # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both @@ -10866,10 +10870,13 @@ def take_along_axis( slice_sizes.append(0) else: slice_sizes.append(1) - operand_batching_dims.append(i) + operand_batching_dims.append(new_i) start_indices_batching_dims.append(j) + new_i += 1 j += 1 + # Squeeze a to remove singleton dimensions. + a = lax.squeeze(a, dims_to_squeeze) gather_indices_arr = lax.concatenate(gather_indices, dimension=j) dnums = lax.GatherDimensionNumbers( offset_dims=tuple(offset_dims), diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 1d6cfad8a..45698d3b1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4721,6 +4721,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): q1 = np.take_along_axis( h, g, axis=-1) np.testing.assert_equal(q0, q1) + def testTakeAlongAxisInputTensorHasSingletonDimension(self): + h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32) + g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8) + q0 = jnp.take_along_axis(h, g, axis=-2) + q1 = np.take_along_axis( h, g, axis=-2) + np.testing.assert_equal(q0, q1) + def testTakeAlongAxisOutOfBounds(self): x = jnp.arange(10, dtype=jnp.float32) idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])