mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Integrate StableHLO at openxla/stablehlo@c44d9af8
PiperOrigin-RevId: 638559828
This commit is contained in:
parent
f72b0f0ca6
commit
9309592ac3
@ -1820,11 +1820,23 @@ def _gather_lower(ctx, operand, indices, *,
|
||||
|
||||
assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS,
|
||||
GatherScatterMode.CLIP), mode
|
||||
dnums = hlo.GatherDimensionNumbers.get(
|
||||
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
offset_dims=list(dimension_numbers.offset_dims),
|
||||
start_index_map=list(dimension_numbers.start_index_map))
|
||||
try:
|
||||
dnums = hlo.GatherDimensionNumbers.get(
|
||||
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
|
||||
operand_batching_dims=[],
|
||||
start_indices_batching_dims=[],
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
offset_dims=list(dimension_numbers.offset_dims),
|
||||
start_index_map=list(dimension_numbers.start_index_map),
|
||||
)
|
||||
# TODO: b/342182301 - Remove this branch once only the new API is supported
|
||||
except:
|
||||
dnums = hlo.GatherDimensionNumbers.get(
|
||||
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
offset_dims=list(dimension_numbers.offset_dims),
|
||||
start_index_map=list(dimension_numbers.start_index_map),
|
||||
)
|
||||
if not core.is_constant_shape(slice_sizes):
|
||||
slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
|
||||
# TODO(burmako): Fix overly conservative type inference of DynamicGatherOp.
|
||||
@ -2474,11 +2486,23 @@ def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
updates, dnums=dimension_numbers)
|
||||
|
||||
dnums = dimension_numbers
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1)
|
||||
try:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
input_batching_dims=[],
|
||||
scatter_indices_batching_dims=[],
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
# TODO: b/342182301 - Remove this branch once only the new API is supported
|
||||
except:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
result = mlir.aval_to_ir_types(aval_out)
|
||||
operand = [operand]
|
||||
updates = [updates]
|
||||
@ -2531,11 +2555,23 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
|
||||
aval_out, = ctx.avals_out
|
||||
dnums = dimension_numbers
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1)
|
||||
try:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
input_batching_dims=[],
|
||||
scatter_indices_batching_dims=[],
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
# TODO: b/342182301 - Remove this branch once only the new API is supported
|
||||
except:
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
)
|
||||
real_dtype = _real_dtype(aval_out.dtype)
|
||||
operand_type_part = mlir.aval_to_ir_types(
|
||||
core.ShapedArray(aval_out.shape, real_dtype))
|
||||
|
Loading…
x
Reference in New Issue
Block a user