mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 06:16:05 +00:00
[MLIR][Shape] Add canonicalization pattern for shape.rank
Replace any `rank(shape_of(tensor))` that relies on a ranked tensor with the corresponding constant `const_size`. Differential Revision: https://reviews.llvm.org/D82077
This commit is contained in:
parent
0045786f14
commit
7bca97d960
@ -130,6 +130,10 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
|
||||
let arguments = (ins IndexAttr:$value);
|
||||
let results = (outs Shape_SizeType:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result, int64_t value">
|
||||
];
|
||||
|
||||
let assemblyFormat = "$value attr-dict";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
@ -181,6 +185,7 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
|
||||
let assemblyFormat = "attr-dict $shape";
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
|
||||
|
@ -364,6 +364,11 @@ OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
|
||||
// ConstSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
|
||||
int64_t value) {
|
||||
build(builder, result, builder.getIndexAttr(value));
|
||||
}
|
||||
|
||||
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
|
||||
|
||||
void ConstSizeOp::getAsmResultNames(
|
||||
@ -450,6 +455,45 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
return builder.getIndexAttr(rank);
|
||||
}
|
||||
|
||||
/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
|
||||
/// Constant folding fails in cases where only the rank is constant, not the
|
||||
/// shape itself.
|
||||
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
|
||||
/// %rank = shape.rank %shape
|
||||
///
|
||||
/// becomes
|
||||
///
|
||||
/// %rank = shape.const_size 3
|
||||
|
||||
namespace {
|
||||
struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
|
||||
using OpRewritePattern<RankOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(RankOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
|
||||
if (!shapeOfOp)
|
||||
return failure();
|
||||
auto rankedTensorType =
|
||||
shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
|
||||
if (!rankedTensorType)
|
||||
return failure();
|
||||
int64_t rank = rankedTensorType.getRank();
|
||||
rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.insert<RankShapeOfCanonicalizationPattern>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NumElementsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -466,3 +466,29 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
|
||||
%rank = shape.rank %shape
|
||||
return %rank : !shape.size
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Canonicalize `rank` when shape is derived from ranked tensor.
|
||||
// CHECK-LABEL: @canonicalize_rank
|
||||
func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
|
||||
// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
|
||||
// CHECK-DAG: return %[[RESULT]] : !shape.size
|
||||
%shape = shape.shape_of %arg : tensor<1x2x?xf32>
|
||||
%rank = shape.rank %shape
|
||||
return %rank : !shape.size
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Do not canonicalize `rank` when shape is derived from unranked tensor.
|
||||
// CHECK-LABEL: @dont_canonicalize_rank
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
|
||||
func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
|
||||
// CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
|
||||
// CHECK-DAG: return %[[SIZE]] : !shape.size
|
||||
%shape = shape.shape_of %arg : tensor<*xf32>
|
||||
%rank = shape.rank %shape
|
||||
return %rank : !shape.size
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user