[shape_poly] Fix the lowering of lax.dynamic_slice with shape poly

lax.dynamic_slice clamps the start indices to ensure the access stays
in bounds, but when lowering with shape polymorphism we use
stablehlo.RealDynamicSliceOp, which is a version of SliceOp and does
not clamp. We now clamp the start indices explicitly during lowering.

This changes the lowering only in very limited circumstances: a
lax.dynamic_slice with shape polymorphism and a dynamic slice size,
under native lowering only, and only when the start indices are out
of bounds.
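
As a minimal illustration of the clamping semantics being preserved
(not part of the commit; runnable with stock jax and numpy):

    import numpy as np
    from jax import lax

    x = np.arange(5.0)
    # Start 3 with slice size 4 would read past the end of a length-5
    # array; lax.dynamic_slice clamps the start to max(0, 5 - 4) = 1.
    print(lax.dynamic_slice(x, (3,), (4,)))   # [1. 2. 3. 4.]
    # Negative starts are clamped up to 0.
    print(lax.dynamic_slice(x, (-2,), (3,)))  # [0. 1. 2.]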
George Necula 2023-06-14 08:40:47 +03:00
parent 0d881a4ea6
commit 645b3c4297
2 changed files with 34 additions and 16 deletions


@@ -1356,6 +1356,7 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
 def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
                   start_indices) -> ir.Value:
+  x_aval = ctx.avals_in[0]
   if dtypes.is_opaque_dtype(aval_out.dtype):
     elt_shape = aval_out.dtype._rules.physical_element_aval(
         aval_out.dtype).shape
@@ -1364,23 +1365,30 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
         index_avals[0].dtype if index_avals else 'int64')  # type: ignore
     trailing_zeros = [ir_constant(np.array(0, dtype))] * len(elt_shape)
     start_indices = (*start_indices, *trailing_zeros)
-    physical_aval_out = core.physical_aval(aval_out)
-    return dynamic_slice(ctx, physical_aval_out, x,
-                         start_indices=start_indices)
-  slice_sizes = aval_out.shape
-  if not core.is_constant_shape(slice_sizes):
-    slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
-    return hlo.RealDynamicSliceOp(
-        aval_to_ir_type(aval_out), x,
-        shape_tensor(start_indices),
-        hlo.AddOp(shape_tensor(start_indices),
-                  shape_tensor(slice_sizes)).result,
-        shape_tensor([1] * len(slice_sizes))
-    ).result
-  else:
-    return hlo.DynamicSliceOp(x, start_indices,
-                              dense_int_elements(slice_sizes)).result
+    aval_out = core.physical_aval(aval_out)
+    x_aval = core.physical_aval(x_aval)
+  slice_sizes = aval_out.shape
+  if not core.is_constant_shape(slice_sizes):
+    # lax.dynamic_slice clamps the start indices, but we are going to
+    # lower to RealDynamicSliceOp, which is a version of SliceOp, and does
+    # not have the clamping behavior. We clamp start ourselves.
+    slice_sizes = shape_tensor(eval_dynamic_shape(ctx, slice_sizes))
+    clamped_start = hlo.ClampOp(
+        shape_tensor([0] * len(start_indices)),
+        shape_tensor(start_indices),
+        hlo.SubtractOp(
+            shape_tensor(eval_dynamic_shape(ctx, x_aval.shape)),  # type: ignore
+            slice_sizes))
+    return hlo.RealDynamicSliceOp(
+        aval_to_ir_type(aval_out), x,
+        clamped_start,
+        hlo.AddOp(clamped_start, slice_sizes).result,
+        shape_tensor([1] * len(start_indices))
+    ).result
+  else:
+    return hlo.DynamicSliceOp(x, start_indices,
+                              dense_int_elements(slice_sizes)).result

 def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
                          start_indices) -> ir.Value:
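
In NumPy terms, the new clamping computes start/limit indices for
RealDynamicSliceOp as below (an illustrative sketch, not code from
the commit; clamped_slice_bounds is a hypothetical helper):

    import numpy as np

    def clamped_slice_bounds(start_indices, x_shape, slice_sizes):
      # Clamp each start into [0, dim - size], as the ClampOp above does;
      # the RealDynamicSliceOp limit is then start + size, with stride 1.
      start = np.clip(np.asarray(start_indices), 0,
                      np.asarray(x_shape) - np.asarray(slice_sizes))
      return start, start + np.asarray(slice_sizes)

    # The out-of-bounds starts from the new tests below:
    print(clamped_slice_bounds((1, 1), (3, 4), (3, 2)))   # (array([0, 1]), array([3, 3]))
    print(clamped_slice_bounds((-1, 1), (3, 4), (2, 2)))  # (array([0, 1]), array([2, 3]))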


@@ -2054,6 +2054,16 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
                 arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
                 poly_axes=[0, None]).both_enable_and_disable_xla(),
+    PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_large",
+                # x:shape: (b, 4)
+                lambda x: lax.dynamic_slice(x, (1, 1), (x.shape[0], 2)),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
+    PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_small",
+                # x:shape: (b, 4)
+                lambda x: lax.dynamic_slice(x, (-1, 1), (x.shape[0] - 1, 2)),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
     PolyHarness("dynamic_slice_in_dim", "idx=0",
                 # x:shape: (b, 4)
                 lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
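
In the two new harnesses the dim-0 slice size depends on the symbolic
dimension b, so core.is_constant_shape is False and the clamped
RealDynamicSliceOp path above is exercised. As a sanity check outside
the harness, the same calls run eagerly show the expected clamped
results:

    import numpy as np
    from jax import lax

    x = np.arange(12.0).reshape(3, 4)
    # "idx=tuple_int_start_oob_large": start (1, 1), sizes (3, 2) on a
    # (3, 4) input; the dim-0 start clamps from 1 down to 0.
    print(lax.dynamic_slice(x, (1, 1), (3, 2)))
    # "idx=tuple_int_start_oob_small": start (-1, 1), sizes (2, 2); the
    # dim-0 start clamps from -1 up to 0.
    print(lax.dynamic_slice(x, (-1, 1), (2, 2)))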