mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix the type annotations and don't += a generator (it's confusing)
The code clearly needs those variables to be lists (it mutates, through `.append` and such). PiperOrigin-RevId: 727029815
This commit is contained in:
parent
4b94665f4f
commit
36d7f8530b
@ -787,9 +787,9 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
collapsed_y_axis = 0 # Current axis in y, after collapsing.
|
||||
|
||||
# Scatter dimension numbers.
|
||||
offset_dims: Sequence[int] = []
|
||||
collapsed_slice_dims: Sequence[int] = []
|
||||
start_index_map: Sequence[int] = []
|
||||
offset_dims: list[int] = []
|
||||
collapsed_slice_dims: list[int] = []
|
||||
start_index_map: list[int] = []
|
||||
|
||||
use_64bit_index = (
|
||||
any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape) and
|
||||
@ -800,22 +800,22 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
# Pairs of (array, start_dim) values. These will be broadcast into
|
||||
# gather_indices_shape, with the array dimensions aligned to start_dim, and
|
||||
# then concatenated.
|
||||
gather_indices: Sequence[tuple[Array, int]] = []
|
||||
gather_indices_shape: Sequence[int] = []
|
||||
gather_indices: list[tuple[Array, int]] = []
|
||||
gather_indices_shape: list[int] = []
|
||||
|
||||
# We perform three transformations to y before the scatter op, in order:
|
||||
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
|
||||
# the right shape.
|
||||
slice_shape: Sequence[int] = []
|
||||
slice_shape: list[int] = []
|
||||
|
||||
# Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
|
||||
# indices, which the scatter cannot remove itself.
|
||||
newaxis_dims: Sequence[int] = []
|
||||
newaxis_dims: list[int] = []
|
||||
|
||||
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
|
||||
reversed_y_dims: Sequence[int] = []
|
||||
reversed_y_dims: list[int] = []
|
||||
|
||||
gather_slice_shape: Sequence[int] = []
|
||||
gather_slice_shape: list[int] = []
|
||||
|
||||
for idx_pos, i in enumerate(idx):
|
||||
# Handle the advanced indices here if:
|
||||
@ -829,10 +829,13 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
ndim = len(shape)
|
||||
|
||||
start_dim = len(gather_indices_shape)
|
||||
gather_indices += ((lax.convert_element_type(a, index_dtype), start_dim)
|
||||
for a in advanced_index_arrs)
|
||||
gather_indices.extend(
|
||||
(lax.convert_element_type(a, index_dtype), start_dim)
|
||||
for a in advanced_index_arrs
|
||||
)
|
||||
gather_indices_shape += shape
|
||||
|
||||
assert x_advanced_axes is not None
|
||||
start_index_map.extend(x_advanced_axes)
|
||||
collapsed_slice_dims.extend(x_advanced_axes)
|
||||
slice_shape.extend(shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user