PiperOrigin-RevId: 638559828
Michael Levesque-Dion 2024-05-30 01:03:34 -07:00 committed by jax authors
parent f72b0f0ca6
commit 9309592ac3

@@ -1820,11 +1820,23 @@ def _gather_lower(ctx, operand, indices, *,
   assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS,
                   GatherScatterMode.CLIP), mode
-  dnums = hlo.GatherDimensionNumbers.get(
-      collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
-      index_vector_dim=len(ctx.avals_in[1].shape) - 1,
-      offset_dims=list(dimension_numbers.offset_dims),
-      start_index_map=list(dimension_numbers.start_index_map))
+  try:
+    dnums = hlo.GatherDimensionNumbers.get(
+        collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
+        operand_batching_dims=[],
+        start_indices_batching_dims=[],
+        index_vector_dim=len(ctx.avals_in[1].shape) - 1,
+        offset_dims=list(dimension_numbers.offset_dims),
+        start_index_map=list(dimension_numbers.start_index_map),
+    )
+  # TODO: b/342182301 - Remove this branch once only the new API is supported
+  except:
+    dnums = hlo.GatherDimensionNumbers.get(
+        collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
+        index_vector_dim=len(ctx.avals_in[1].shape) - 1,
+        offset_dims=list(dimension_numbers.offset_dims),
+        start_index_map=list(dimension_numbers.start_index_map),
+    )
   if not core.is_constant_shape(slice_sizes):
     slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
   # TODO(burmako): Fix overly conservative type inference of DynamicGatherOp.
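The hunk above guards against MLIR Python bindings that predate the new batching-dims keywords: the new-style call is attempted first and, if it fails, the lowering retries with the old keyword set. Below is a minimal, runnable sketch of the same keyword-compatibility fallback; make_gather_dnums is a hypothetical stand-in for an older hlo.GatherDimensionNumbers.get, not the real bindings.

    # Hypothetical stand-in for an older hlo.GatherDimensionNumbers.get that
    # does not know the new batching-dims keywords; passing them raises
    # TypeError, just as an outdated binding would reject unknown arguments.
    def make_gather_dnums(*, collapsed_slice_dims, index_vector_dim,
                          offset_dims, start_index_map):
      return dict(collapsed_slice_dims=collapsed_slice_dims,
                  index_vector_dim=index_vector_dim,
                  offset_dims=offset_dims,
                  start_index_map=start_index_map)

    def gather_dnums_with_fallback(collapsed_slice_dims, index_vector_dim,
                                   offset_dims, start_index_map):
      try:
        # Newer bindings: pass explicit (empty) batching dims.
        return make_gather_dnums(
            collapsed_slice_dims=collapsed_slice_dims,
            operand_batching_dims=[],
            start_indices_batching_dims=[],
            index_vector_dim=index_vector_dim,
            offset_dims=offset_dims,
            start_index_map=start_index_map)
      except TypeError:
        # Older bindings reject the unknown keywords, so retry without them.
        return make_gather_dnums(
            collapsed_slice_dims=collapsed_slice_dims,
            index_vector_dim=index_vector_dim,
            offset_dims=offset_dims,
            start_index_map=start_index_map)

    print(gather_dnums_with_fallback([0], 1, [1], [0]))

The sketch catches TypeError, which is what a pure-Python keyword mismatch raises; the commit itself uses a bare except:, staying agnostic about the exception type the installed bindings produce, and the TODO notes that the whole fallback branch disappears once only the new API is supported.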
@@ -2474,11 +2486,23 @@ def _scatter_lower(ctx, operand, indices, updates, *,
                          updates, dnums=dimension_numbers)
   dnums = dimension_numbers
-  scatter_dnums = hlo.ScatterDimensionNumbers.get(
-      update_window_dims=list(dnums.update_window_dims),
-      inserted_window_dims=list(dnums.inserted_window_dims),
-      scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
-      index_vector_dim=len(ctx.avals_in[1].shape) - 1)
+  try:
+    scatter_dnums = hlo.ScatterDimensionNumbers.get(
+        update_window_dims=list(dnums.update_window_dims),
+        inserted_window_dims=list(dnums.inserted_window_dims),
+        input_batching_dims=[],
+        scatter_indices_batching_dims=[],
+        scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
+        index_vector_dim=len(ctx.avals_in[1].shape) - 1,
+    )
+  # TODO: b/342182301 - Remove this branch once only the new API is supported
+  except:
+    scatter_dnums = hlo.ScatterDimensionNumbers.get(
+        update_window_dims=list(dnums.update_window_dims),
+        inserted_window_dims=list(dnums.inserted_window_dims),
+        scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
+        index_vector_dim=len(ctx.avals_in[1].shape) - 1,
+    )
   result = mlir.aval_to_ir_types(aval_out)
   operand = [operand]
   updates = [updates]
@@ -2531,11 +2555,23 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
   aval_out, = ctx.avals_out
   dnums = dimension_numbers
-  scatter_dnums = hlo.ScatterDimensionNumbers.get(
-      update_window_dims=list(dnums.update_window_dims),
-      inserted_window_dims=list(dnums.inserted_window_dims),
-      scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
-      index_vector_dim=len(ctx.avals_in[1].shape) - 1)
+  try:
+    scatter_dnums = hlo.ScatterDimensionNumbers.get(
+        update_window_dims=list(dnums.update_window_dims),
+        inserted_window_dims=list(dnums.inserted_window_dims),
+        input_batching_dims=[],
+        scatter_indices_batching_dims=[],
+        scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
+        index_vector_dim=len(ctx.avals_in[1].shape) - 1,
+    )
+  # TODO: b/342182301 - Remove this branch once only the new API is supported
+  except:
+    scatter_dnums = hlo.ScatterDimensionNumbers.get(
+        update_window_dims=list(dnums.update_window_dims),
+        inserted_window_dims=list(dnums.inserted_window_dims),
+        scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
+        index_vector_dim=len(ctx.avals_in[1].shape) - 1,
+    )
   real_dtype = _real_dtype(aval_out.dtype)
   operand_type_part = mlir.aval_to_ir_types(
       core.ShapedArray(aval_out.shape, real_dtype))
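Both scatter hunks apply the same fallback, differing from the gather case only in the keyword names (input_batching_dims and scatter_indices_batching_dims). As a rough illustration of how the two scatter call sites could share that logic, here is a hypothetical helper; make_scatter_dnums again stands in for an older hlo.ScatterDimensionNumbers.get, and nothing below is part of the commit itself.

    # Hypothetical stand-in for an older hlo.ScatterDimensionNumbers.get that
    # does not accept the new batching-dims keywords.
    def make_scatter_dnums(*, update_window_dims, inserted_window_dims,
                           scattered_dims_to_operand_dims, index_vector_dim):
      return dict(update_window_dims=update_window_dims,
                  inserted_window_dims=inserted_window_dims,
                  scattered_dims_to_operand_dims=scattered_dims_to_operand_dims,
                  index_vector_dim=index_vector_dim)

    def scatter_dnums_with_fallback(dnums, index_vector_dim):
      # Keyword set shared by the old and new signatures.
      common = dict(
          update_window_dims=list(dnums["update_window_dims"]),
          inserted_window_dims=list(dnums["inserted_window_dims"]),
          scattered_dims_to_operand_dims=list(
              dnums["scatter_dims_to_operand_dims"]),
          index_vector_dim=index_vector_dim)
      try:
        # Newer bindings: pass explicit (empty) batching dims.
        return make_scatter_dnums(input_batching_dims=[],
                                  scatter_indices_batching_dims=[], **common)
      except TypeError:
        # Older bindings reject the unknown keywords, so retry without them.
        return make_scatter_dnums(**common)

    example = {"update_window_dims": (1,), "inserted_window_dims": (0,),
               "scatter_dims_to_operand_dims": (0,)}
    print(scatter_dnums_with_fallback(example, index_vector_dim=1))

Whether to factor the probe into one helper or keep it inline is a style choice; the commit keeps the try/except at each of the three call sites, pending removal under b/342182301.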