mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix the type annotations and don't += a generator (it's confusing)
The code clearly needs those variables to be lists (it mutates, through `.append` and such). PiperOrigin-RevId: 727029815
This commit is contained in:
parent
4b94665f4f
commit
36d7f8530b
@ -787,9 +787,9 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
|||||||
collapsed_y_axis = 0 # Current axis in y, after collapsing.
|
collapsed_y_axis = 0 # Current axis in y, after collapsing.
|
||||||
|
|
||||||
# Scatter dimension numbers.
|
# Scatter dimension numbers.
|
||||||
offset_dims: Sequence[int] = []
|
offset_dims: list[int] = []
|
||||||
collapsed_slice_dims: Sequence[int] = []
|
collapsed_slice_dims: list[int] = []
|
||||||
start_index_map: Sequence[int] = []
|
start_index_map: list[int] = []
|
||||||
|
|
||||||
use_64bit_index = (
|
use_64bit_index = (
|
||||||
any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape) and
|
any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape) and
|
||||||
@ -800,22 +800,22 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
|||||||
# Pairs of (array, start_dim) values. These will be broadcast into
|
# Pairs of (array, start_dim) values. These will be broadcast into
|
||||||
# gather_indices_shape, with the array dimensions aligned to start_dim, and
|
# gather_indices_shape, with the array dimensions aligned to start_dim, and
|
||||||
# then concatenated.
|
# then concatenated.
|
||||||
gather_indices: Sequence[tuple[Array, int]] = []
|
gather_indices: list[tuple[Array, int]] = []
|
||||||
gather_indices_shape: Sequence[int] = []
|
gather_indices_shape: list[int] = []
|
||||||
|
|
||||||
# We perform three transformations to y before the scatter op, in order:
|
# We perform three transformations to y before the scatter op, in order:
|
||||||
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
|
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
|
||||||
# the right shape.
|
# the right shape.
|
||||||
slice_shape: Sequence[int] = []
|
slice_shape: list[int] = []
|
||||||
|
|
||||||
# Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
|
# Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
|
||||||
# indices, which the scatter cannot remove itself.
|
# indices, which the scatter cannot remove itself.
|
||||||
newaxis_dims: Sequence[int] = []
|
newaxis_dims: list[int] = []
|
||||||
|
|
||||||
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
|
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
|
||||||
reversed_y_dims: Sequence[int] = []
|
reversed_y_dims: list[int] = []
|
||||||
|
|
||||||
gather_slice_shape: Sequence[int] = []
|
gather_slice_shape: list[int] = []
|
||||||
|
|
||||||
for idx_pos, i in enumerate(idx):
|
for idx_pos, i in enumerate(idx):
|
||||||
# Handle the advanced indices here if:
|
# Handle the advanced indices here if:
|
||||||
@ -829,10 +829,13 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
|||||||
ndim = len(shape)
|
ndim = len(shape)
|
||||||
|
|
||||||
start_dim = len(gather_indices_shape)
|
start_dim = len(gather_indices_shape)
|
||||||
gather_indices += ((lax.convert_element_type(a, index_dtype), start_dim)
|
gather_indices.extend(
|
||||||
for a in advanced_index_arrs)
|
(lax.convert_element_type(a, index_dtype), start_dim)
|
||||||
|
for a in advanced_index_arrs
|
||||||
|
)
|
||||||
gather_indices_shape += shape
|
gather_indices_shape += shape
|
||||||
|
|
||||||
|
assert x_advanced_axes is not None
|
||||||
start_index_map.extend(x_advanced_axes)
|
start_index_map.extend(x_advanced_axes)
|
||||||
collapsed_slice_dims.extend(x_advanced_axes)
|
collapsed_slice_dims.extend(x_advanced_axes)
|
||||||
slice_shape.extend(shape)
|
slice_shape.extend(shape)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user