Fix the type annotations and don't += a generator (it's confusing)

The code clearly needs those variables to be lists (it mutates, through
`.append` and such).

PiperOrigin-RevId: 727029815
This commit is contained in:
Jake VanderPlas 2025-02-14 12:45:13 -08:00 committed by jax authors
parent 4b94665f4f
commit 36d7f8530b

View File

@ -787,9 +787,9 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
collapsed_y_axis = 0 # Current axis in y, after collapsing.
# Scatter dimension numbers.
offset_dims: Sequence[int] = []
collapsed_slice_dims: Sequence[int] = []
start_index_map: Sequence[int] = []
offset_dims: list[int] = []
collapsed_slice_dims: list[int] = []
start_index_map: list[int] = []
use_64bit_index = (
any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape) and
@ -800,22 +800,22 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# Pairs of (array, start_dim) values. These will be broadcast into
# gather_indices_shape, with the array dimensions aligned to start_dim, and
# then concatenated.
gather_indices: Sequence[tuple[Array, int]] = []
gather_indices_shape: Sequence[int] = []
gather_indices: list[tuple[Array, int]] = []
gather_indices_shape: list[int] = []
# We perform three transformations to y before the scatter op, in order:
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
# the right shape.
slice_shape: Sequence[int] = []
slice_shape: list[int] = []
# Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
# indices, which the scatter cannot remove itself.
newaxis_dims: Sequence[int] = []
newaxis_dims: list[int] = []
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
reversed_y_dims: Sequence[int] = []
reversed_y_dims: list[int] = []
gather_slice_shape: Sequence[int] = []
gather_slice_shape: list[int] = []
for idx_pos, i in enumerate(idx):
# Handle the advanced indices here if:
@ -829,10 +829,13 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
ndim = len(shape)
start_dim = len(gather_indices_shape)
gather_indices += ((lax.convert_element_type(a, index_dtype), start_dim)
for a in advanced_index_arrs)
gather_indices.extend(
(lax.convert_element_type(a, index_dtype), start_dim)
for a in advanced_index_arrs
)
gather_indices_shape += shape
assert x_advanced_axes is not None
start_index_map.extend(x_advanced_axes)
collapsed_slice_dims.extend(x_advanced_axes)
slice_shape.extend(shape)