River Riddle 6563b1c446 Add a new dialect interface for the OperationFolder OpFolderDialectInterface.
This interface will allow for providing hooks to interop with operation folding. The first hook, 'shouldMaterializeInto', will allow for controlling which region to insert materialized constants into. The folder will generally materialize constants into the top-level isolated region; this allows for materializing into a lower-level ancestor region if it is more profitable/correct.

PiperOrigin-RevId: 266702972
2019-09-01 20:07:08 -07:00

131 lines
4.8 KiB
C++

//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "TestDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
/// Registered hook to check if the given region, which is attached to an
/// operation that is *not* isolated from above, should be used when
/// materializing constants.
virtual bool shouldMaterializeInto(Region *region) const {
// If this is a one region operation, then insert into it.
return isa<OneRegionOp>(region->getParentOp());
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
/// Constructs the test dialect under the name returned by getDialectName().
/// The body registers the dialect's operations, attaches
/// TestOpFolderDialectInterface, and calls allowUnknownOperations().
TestDialect::TestDialect(MLIRContext *context)
: Dialect(getDialectName(), context) {
// The template arguments of addOperations are supplied by the preprocessor:
// GET_OP_LIST is defined immediately before including TestOps.cpp.inc, and
// the included text expands in place between '<' and '>'. Presumably that
// section is the list of op classes of this dialect - confirm against the
// generated file.
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
// Attach the folder hooks defined earlier in this file.
addInterfaces<TestOpFolderDialectInterface>();
allowUnknownOperations();
}
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
/// Custom parser for the isolated region test op: one operand followed by a
/// body region. The operand is resolved as an index-typed value and its
/// parsed name is then reused, with name shadowing enabled, as the single
/// argument of the region.
static ParseResult parseIsolatedRegionOp(OpAsmParser *parser,
                                         OperationState *result) {
  // Both the operand and the region argument use the index type.
  Type indexType = parser->getBuilder().getIndexType();
  OpAsmParser::OperandType operandInfo;

  // Read the operand, then resolve it into the operation state. Either step
  // failing aborts the parse.
  if (parser->parseOperand(operandInfo))
    return failure();
  if (parser->resolveOperand(operandInfo, indexType, result->operands))
    return failure();

  // The region argument takes the operand's name and type.
  Region *region = result->addRegion();
  return parser->parseRegion(*region, operandInfo, indexType,
                             /*enableNameShadowing=*/true);
}
/// Custom printer for IsolatedRegionOp: emits the op name, its operand, and
/// then the region with the entry block arguments suppressed. Before the
/// region is printed, its arguments are shadowed with the operand.
static void print(OpAsmPrinter *p, IsolatedRegionOp op) {
  auto operand = op.getOperand();
  auto &body = op.region();

  *p << "test.isolated_region ";
  p->printOperand(operand);

  // Have the region arguments print under the operand's name.
  p->shadowRegionArgs(body, operand);
  p->printRegion(body, /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
/// Custom parser for the PolyForOp test op: an undelimited list of region
/// arguments followed by the body region. Every parsed argument is given the
/// index type.
static ParseResult parsePolyForOp(OpAsmParser *parser, OperationState *result) {
  // Collect the region arguments; the list has no surrounding delimiter.
  SmallVector<OpAsmParser::OperandType, 4> regionArgs;
  if (parser->parseRegionArgumentList(regionArgs))
    return failure();

  // Create the region and parse it with one index type per argument.
  Region *region = result->addRegion();
  Type indexType = parser->getBuilder().getIndexType();
  SmallVector<Type, 4> regionArgTypes(regionArgs.size(), indexType);
  return parser->parseRegion(*region, regionArgs, regionArgTypes);
}
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
namespace {
/// Rewrite pattern over TestOpWithRegionPattern that always matches and
/// replaces the matched operation with an empty set of values.
struct TestRemoveOpWithInnerOps
    : public OpRewritePattern<TestOpWithRegionPattern> {
  using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;

  /// Unconditionally replaces `regionOp` with no results and reports a
  /// successful match.
  PatternMatchResult matchAndRewrite(TestOpWithRegionPattern regionOp,
                                     PatternRewriter &rewriter) const override {
    rewriter.replaceOp(regionOp, llvm::None);
    return matchSuccess();
  }
};
} // end anonymous namespace
/// Adds the canonicalization patterns of TestOpWithRegionPattern to
/// `results`; the only one is TestRemoveOpWithInnerOps, constructed with
/// `context`.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TestRemoveOpWithInnerOps>(context);
}
/// Folds this operation to its own operand. The constant `operands` handed
/// in by the folder are not consulted.
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
  return operand();
}
// Static initialization for Test dialect registration: this file-local
// DialectRegistration object is constructed during static initialization of
// this translation unit, and nothing else in this file refers to it by name.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
#define GET_OP_CLASSES
#include "TestOps.cpp.inc"