mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
inline and remove slice_mlir
rules
This commit is contained in:
parent
129a4a5f35
commit
aed77c5031
@ -1254,23 +1254,31 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va
|
||||
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
|
||||
start_indices, limit_indices, strides) -> ir.Value:
|
||||
if dtypes.is_opaque_dtype(aval_out.dtype):
|
||||
return [aval_out.dtype._rules.slice_mlir(
|
||||
ctx, aval_out, x, start_indices, limit_indices, strides)]
|
||||
|
||||
if any(not core.is_constant_shape(s) for s in (start_indices, limit_indices, strides)):
|
||||
start_indices = eval_dynamic_shape(ctx, start_indices)
|
||||
limit_indices = eval_dynamic_shape(ctx, limit_indices)
|
||||
strides = eval_dynamic_shape(ctx, strides)
|
||||
return hlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
|
||||
x,
|
||||
shape_tensor(start_indices),
|
||||
shape_tensor(limit_indices),
|
||||
shape_tensor(strides)).result
|
||||
elt_shape = aval_out.dtype._rules.physical_element_aval(
|
||||
aval_out.dtype).shape
|
||||
trailing_zeros = [0] * len(elt_shape)
|
||||
trailing_ones = [1] * len(elt_shape)
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
limit_indices = (*limit_indices, *elt_shape)
|
||||
strides = (*strides, *trailing_ones)
|
||||
physical_aval_out = core.physical_aval(aval_out)
|
||||
return slice_op(ctx, x, physical_aval_out, start_indices=start_indices,
|
||||
limit_indices=limit_indices, strides=strides)
|
||||
else:
|
||||
return hlo.SliceOp(x,
|
||||
dense_int_elements(start_indices),
|
||||
dense_int_elements(limit_indices),
|
||||
dense_int_elements(strides)).result
|
||||
if any(not core.is_constant_shape(s) for s in (start_indices, limit_indices, strides)):
|
||||
start_indices = eval_dynamic_shape(ctx, start_indices)
|
||||
limit_indices = eval_dynamic_shape(ctx, limit_indices)
|
||||
strides = eval_dynamic_shape(ctx, strides)
|
||||
return hlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
|
||||
x,
|
||||
shape_tensor(start_indices),
|
||||
shape_tensor(limit_indices),
|
||||
shape_tensor(strides)).result
|
||||
else:
|
||||
return hlo.SliceOp(x,
|
||||
dense_int_elements(start_indices),
|
||||
dense_int_elements(limit_indices),
|
||||
dense_int_elements(strides)).result
|
||||
|
||||
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
start_indices) -> ir.Value:
|
||||
|
@ -521,18 +521,6 @@ class KeyTyRules:
|
||||
|
||||
# element-type-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides) -> ir.Value:
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_zeros = [0] * len(key_shape)
|
||||
trailing_ones = [1] * len(key_shape)
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
limit_indices = (*limit_indices, *key_shape)
|
||||
strides = (*strides, *trailing_ones)
|
||||
physical_aval_out = core.physical_aval(aval_out)
|
||||
return mlir.slice_op(ctx, x, physical_aval_out,
|
||||
start_indices=start_indices, limit_indices=limit_indices, strides=strides)
|
||||
|
||||
@staticmethod
|
||||
def dynamic_slice_mlir(ctx, aval_out, x, start_indices) -> ir.Value:
|
||||
index_avals = ctx.avals_in[1:]
|
||||
|
@ -2869,16 +2869,6 @@ class FooTyRules:
|
||||
|
||||
# element-type-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides):
|
||||
start_indices = (*start_indices, 0)
|
||||
limit_indices = (*limit_indices, 2)
|
||||
strides = (*strides, 1)
|
||||
return hlo.SliceOp(x,
|
||||
mlir.dense_int_elements(start_indices),
|
||||
mlir.dense_int_elements(limit_indices),
|
||||
mlir.dense_int_elements(strides)).result
|
||||
|
||||
@staticmethod
|
||||
def dynamic_slice_mlir(ctx, aval_out, x, start_indices):
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user