mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Clean up gather/scatter StableHLO lowering.
PiperOrigin-RevId: 646491586
This commit is contained in:
parent
50407e536e
commit
a4c92a454b
@ -1820,23 +1820,14 @@ def _gather_lower(ctx, operand, indices, *,
|
||||
|
||||
assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS,
|
||||
GatherScatterMode.CLIP), mode
|
||||
try:
|
||||
dnums = hlo.GatherDimensionNumbers.get(
|
||||
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
|
||||
operand_batching_dims=[],
|
||||
start_indices_batching_dims=[],
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
offset_dims=list(dimension_numbers.offset_dims),
|
||||
start_index_map=list(dimension_numbers.start_index_map),
|
||||
)
|
||||
# TODO: b/342182301 - Remove this branch once only the new API is supported
|
||||
except:
|
||||
dnums = hlo.GatherDimensionNumbers.get(
|
||||
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
offset_dims=list(dimension_numbers.offset_dims),
|
||||
start_index_map=list(dimension_numbers.start_index_map),
|
||||
)
|
||||
dnums = hlo.GatherDimensionNumbers.get(
|
||||
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
|
||||
operand_batching_dims=[],
|
||||
start_indices_batching_dims=[],
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
offset_dims=list(dimension_numbers.offset_dims),
|
||||
start_index_map=list(dimension_numbers.start_index_map),
|
||||
)
|
||||
if not core.is_constant_shape(slice_sizes):
|
||||
slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
|
||||
# TODO(burmako): Fix overly conservative type inference of DynamicGatherOp.
|
||||
@ -2486,23 +2477,14 @@ def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
updates, dnums=dimension_numbers)
|
||||
|
||||
dnums = dimension_numbers
|
||||
try:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
input_batching_dims=[],
|
||||
scatter_indices_batching_dims=[],
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
# TODO: b/342182301 - Remove this branch once only the new API is supported
|
||||
except:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
input_batching_dims=[],
|
||||
scatter_indices_batching_dims=[],
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
result = mlir.aval_to_ir_types(aval_out)
|
||||
operand = [operand]
|
||||
updates = [updates]
|
||||
@ -2555,23 +2537,14 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
|
||||
aval_out, = ctx.avals_out
|
||||
dnums = dimension_numbers
|
||||
try:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
input_batching_dims=[],
|
||||
scatter_indices_batching_dims=[],
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
# TODO: b/342182301 - Remove this branch once only the new API is supported
|
||||
except:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
input_batching_dims=[],
|
||||
scatter_indices_batching_dims=[],
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
real_dtype = _real_dtype(aval_out.dtype)
|
||||
operand_type_part = mlir.aval_to_ir_types(
|
||||
core.ShapedArray(aval_out.shape, real_dtype))
|
||||
|
Loading…
x
Reference in New Issue
Block a user