[MLIR][Shape] Add canonicalization pattern for shape.rank

Replace any `rank(shape_of(tensor))` where the tensor argument is ranked with the
corresponding `const_size` constant.

Differential Revision: https://reviews.llvm.org/D82077
Frederik Gossen 2020-06-25 08:37:18 +00:00
parent 0045786f14
commit 7bca97d960
3 changed files with 75 additions and 0 deletions


@@ -130,6 +130,10 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
  let arguments = (ins IndexAttr:$value);
  let results = (outs Shape_SizeType:$result);

  let builders = [
    OpBuilder<"OpBuilder &builder, OperationState &result, int64_t value">
  ];

  let assemblyFormat = "$value attr-dict";
  let hasFolder = 1;
}
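The new `builders` entry asks ODS to generate a convenience overload of `ConstSizeOp::build` that takes a plain `int64_t` instead of an already-constructed `IndexAttr`. As a rough sketch (the generated header is not part of this diff, so the exact surroundings are assumed), the declaration emitted into the ConstSizeOp class would look like:

    // Sketch of the ODS-generated declaration for the builder entry above;
    // its definition is added by hand in Shape.cpp further down.
    static void build(OpBuilder &builder, OperationState &result, int64_t value);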
@@ -181,6 +185,7 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
  let assemblyFormat = "attr-dict $shape";
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
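Setting `hasCanonicalizer = 1` makes ODS declare a canonicalization hook on `RankOp`, which the canonicalizer driver calls to collect patterns; the definition is added in Shape.cpp below. A sketch of the assumed generated declaration:

    // Declared on RankOp because of hasCanonicalizer = 1; defined in Shape.cpp.
    static void getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context);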


@@ -364,6 +364,11 @@ OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }

void ConstSizeOp::getAsmResultNames(
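With the `int64_t` overload in place, a `shape.const_size` can be materialized directly from an integer. A minimal, hypothetical call site (`builder` and `loc` are assumed to be in scope; not part of this commit):

    // builder.create<...> forwards its trailing arguments to the matching
    // build() overload, which wraps the integer in an IndexAttr.
    Value three = builder.create<ConstSizeOp>(loc, /*value=*/3).getResult();

This is also what lets the pattern below call `rewriter.replaceOpWithNewOp<ConstSizeOp>(..., rank)` with a bare `int64_t`.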
@@ -450,6 +455,45 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3
namespace {
struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
  using OpRewritePattern<RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
    return success();
  }
};
} // namespace

void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                         MLIRContext *context) {
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//
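The pattern is only registered here; it runs whenever the generic canonicalizer is invoked, e.g. via `mlir-opt -canonicalize` as in the tests below. A rough C++ sketch of driving it through a pass manager (the includes, pass name, and helper function reflect the MLIR API of this period and are an assumption, not part of the commit):

    #include "mlir/IR/Module.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    using namespace mlir;

    // Runs the canonicalizer over a module; this collects every op's
    // canonicalization patterns, including RankShapeOfCanonicalizationPattern
    // registered above, and applies them greedily.
    static LogicalResult canonicalizeModule(ModuleOp module) {
      PassManager pm(module.getContext());
      pm.addPass(createCanonicalizerPass());
      return pm.run(module);
    }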


@@ -466,3 +466,29 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
  %rank = shape.rank %shape
  return %rank : !shape.size
}

// -----

// Canonicalize `rank` when shape is derived from ranked tensor.
// CHECK-LABEL: @canonicalize_rank
func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
  // CHECK-DAG: return %[[RESULT]] : !shape.size
  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
  %rank = shape.rank %shape
  return %rank : !shape.size
}

// -----

// Do not canonicalize `rank` when shape is derived from unranked tensor.
// CHECK-LABEL: @dont_canonicalize_rank
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
  // CHECK-DAG: return %[[SIZE]] : !shape.size
  %shape = shape.shape_of %arg : tensor<*xf32>
  %rank = shape.rank %shape
  return %rank : !shape.size
}