mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Fix the lowering of lax.dynamic_shapes with shape poly
lax.dynamic_slice clamps the start indices to ensure the access is in bounds, but when lowering with shape polymorphism we use stablehlo.RealDynamicSliceOp, which is a version of SliceOp and does not have the clamping. We clamp explicitly during lowering. This changes the lowering in very limited circumstances: when we have a lax.dynamic_slice with shape polymorphism and a dynamic slice size, under native lowering only, and only when the start indices were out of bounds.
This commit is contained in:
parent
0d881a4ea6
commit
645b3c4297
@ -1356,6 +1356,7 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
|
||||
|
||||
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
start_indices) -> ir.Value:
|
||||
x_aval = ctx.avals_in[0]
|
||||
if dtypes.is_opaque_dtype(aval_out.dtype):
|
||||
elt_shape = aval_out.dtype._rules.physical_element_aval(
|
||||
aval_out.dtype).shape
|
||||
@ -1364,23 +1365,30 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
index_avals[0].dtype if index_avals else 'int64') # type: ignore
|
||||
trailing_zeros = [ir_constant(np.array(0, dtype))] * len(elt_shape)
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
physical_aval_out = core.physical_aval(aval_out)
|
||||
return dynamic_slice(ctx, physical_aval_out, x,
|
||||
start_indices=start_indices)
|
||||
aval_out = core.physical_aval(aval_out)
|
||||
x_aval = core.physical_aval(x_aval)
|
||||
|
||||
slice_sizes = aval_out.shape
|
||||
if not core.is_constant_shape(slice_sizes):
|
||||
# lax.dynamic_slice clamps the start indices, but we are going to
|
||||
# lower to RealDynamicSliceOp, which is a version of SliceOp, and does
|
||||
# not have the clamping behavior. We clamp start ourselves.
|
||||
slice_sizes = shape_tensor(eval_dynamic_shape(ctx, slice_sizes))
|
||||
clamped_start = hlo.ClampOp(
|
||||
shape_tensor([0] * len(start_indices)),
|
||||
shape_tensor(start_indices),
|
||||
hlo.SubtractOp(
|
||||
shape_tensor(eval_dynamic_shape(ctx, x_aval.shape)), # type: ignore
|
||||
slice_sizes))
|
||||
return hlo.RealDynamicSliceOp(
|
||||
aval_to_ir_type(aval_out), x,
|
||||
clamped_start,
|
||||
hlo.AddOp(clamped_start, slice_sizes).result,
|
||||
shape_tensor([1] * len(start_indices))
|
||||
).result
|
||||
else:
|
||||
slice_sizes = aval_out.shape
|
||||
if not core.is_constant_shape(slice_sizes):
|
||||
slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
|
||||
return hlo.RealDynamicSliceOp(
|
||||
aval_to_ir_type(aval_out), x,
|
||||
shape_tensor(start_indices),
|
||||
hlo.AddOp(shape_tensor(start_indices),
|
||||
shape_tensor(slice_sizes)).result,
|
||||
shape_tensor([1] * len(slice_sizes))
|
||||
).result
|
||||
else:
|
||||
return hlo.DynamicSliceOp(x, start_indices,
|
||||
dense_int_elements(slice_sizes)).result
|
||||
return hlo.DynamicSliceOp(x, start_indices,
|
||||
dense_int_elements(slice_sizes)).result
|
||||
|
||||
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
|
||||
start_indices) -> ir.Value:
|
||||
|
@ -2054,6 +2054,16 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
|
||||
arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
|
||||
poly_axes=[0, None]).both_enable_and_disable_xla(),
|
||||
PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_large",
|
||||
# x:shape: (b, 4)
|
||||
lambda x: lax.dynamic_slice(x, (1, 1), (x.shape[0], 2)),
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0]).both_enable_and_disable_xla(),
|
||||
PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_small",
|
||||
# x:shape: (b, 4)
|
||||
lambda x: lax.dynamic_slice(x, (-1, 1), (x.shape[0] - 1, 2)),
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0]).both_enable_and_disable_xla(),
|
||||
PolyHarness("dynamic_slice_in_dim", "idx=0",
|
||||
# x:shape: (b, 4)
|
||||
lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
|
||||
|
Loading…
x
Reference in New Issue
Block a user