inline and remove slice_mlir rules

This commit is contained in:
Roy Frostig 2023-05-12 07:59:24 -07:00
parent 129a4a5f35
commit aed77c5031
3 changed files with 24 additions and 38 deletions

View File

@ -1254,23 +1254,31 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices, limit_indices, strides) -> ir.Value:
if dtypes.is_opaque_dtype(aval_out.dtype):
return [aval_out.dtype._rules.slice_mlir(
ctx, aval_out, x, start_indices, limit_indices, strides)]
if any(not core.is_constant_shape(s) for s in (start_indices, limit_indices, strides)):
start_indices = eval_dynamic_shape(ctx, start_indices)
limit_indices = eval_dynamic_shape(ctx, limit_indices)
strides = eval_dynamic_shape(ctx, strides)
return hlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
x,
shape_tensor(start_indices),
shape_tensor(limit_indices),
shape_tensor(strides)).result
elt_shape = aval_out.dtype._rules.physical_element_aval(
aval_out.dtype).shape
trailing_zeros = [0] * len(elt_shape)
trailing_ones = [1] * len(elt_shape)
start_indices = (*start_indices, *trailing_zeros)
limit_indices = (*limit_indices, *elt_shape)
strides = (*strides, *trailing_ones)
physical_aval_out = core.physical_aval(aval_out)
return slice_op(ctx, x, physical_aval_out, start_indices=start_indices,
limit_indices=limit_indices, strides=strides)
else:
return hlo.SliceOp(x,
dense_int_elements(start_indices),
dense_int_elements(limit_indices),
dense_int_elements(strides)).result
if any(not core.is_constant_shape(s) for s in (start_indices, limit_indices, strides)):
start_indices = eval_dynamic_shape(ctx, start_indices)
limit_indices = eval_dynamic_shape(ctx, limit_indices)
strides = eval_dynamic_shape(ctx, strides)
return hlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
x,
shape_tensor(start_indices),
shape_tensor(limit_indices),
shape_tensor(strides)).result
else:
return hlo.SliceOp(x,
dense_int_elements(start_indices),
dense_int_elements(limit_indices),
dense_int_elements(strides)).result
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
start_indices) -> ir.Value:

View File

@ -521,18 +521,6 @@ class KeyTyRules:
# element-type-polymorphic primitive lowering rules
@staticmethod
def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides) -> ir.Value:
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [0] * len(key_shape)
trailing_ones = [1] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
limit_indices = (*limit_indices, *key_shape)
strides = (*strides, *trailing_ones)
physical_aval_out = core.physical_aval(aval_out)
return mlir.slice_op(ctx, x, physical_aval_out,
start_indices=start_indices, limit_indices=limit_indices, strides=strides)
@staticmethod
def dynamic_slice_mlir(ctx, aval_out, x, start_indices) -> ir.Value:
index_avals = ctx.avals_in[1:]

View File

@ -2869,16 +2869,6 @@ class FooTyRules:
# element-type-polymorphic primitive lowering rules
@staticmethod
def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides):
start_indices = (*start_indices, 0)
limit_indices = (*limit_indices, 2)
strides = (*strides, 1)
return hlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).result
@staticmethod
def dynamic_slice_mlir(ctx, aval_out, x, start_indices):
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))