2019-05-23 15:11:19 -07:00
|
|
|
//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
|
|
|
|
//
|
2020-01-26 03:58:30 +00:00
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
2019-12-23 09:35:36 -08:00
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2019-05-23 15:11:19 -07:00
|
|
|
//
|
2019-12-23 09:35:36 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-05-23 15:11:19 -07:00
|
|
|
|
|
|
|
#include "TestDialect.h"
|
2020-02-28 08:37:09 -08:00
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
2019-11-01 14:26:10 -07:00
|
|
|
#include "mlir/IR/Function.h"
|
|
|
|
#include "mlir/IR/Module.h"
|
2019-05-23 15:11:19 -07:00
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2019-07-20 03:03:45 -07:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
2019-09-01 20:06:42 -07:00
|
|
|
#include "mlir/Transforms/FoldUtils.h"
|
2019-09-05 12:23:45 -07:00
|
|
|
#include "mlir/Transforms/InliningUtils.h"
|
2019-11-25 11:29:21 -08:00
|
|
|
#include "llvm/ADT/StringSwitch.h"
|
2019-05-23 15:11:19 -07:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
2019-09-01 20:06:42 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TestDialect Interfaces
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
2019-11-20 10:19:01 -08:00
|
|
|
|
|
|
|
// Test support for interacting with the AsmPrinter.
|
|
|
|
struct TestOpAsmInterface : public OpAsmDialectInterface {
|
|
|
|
using OpAsmDialectInterface::OpAsmDialectInterface;
|
|
|
|
|
|
|
|
void getAsmResultNames(Operation *op,
|
|
|
|
OpAsmSetValueNameFn setNameFn) const final {
|
|
|
|
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
|
|
|
|
setNameFn(asmOp, "result");
|
|
|
|
}
|
2019-12-19 22:15:31 -08:00
|
|
|
|
|
|
|
void getAsmBlockArgumentNames(Block *block,
|
|
|
|
OpAsmSetValueNameFn setNameFn) const final {
|
|
|
|
auto op = block->getParentOp();
|
|
|
|
auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
|
|
|
|
if (!arrayAttr)
|
|
|
|
return;
|
|
|
|
auto args = block->getArguments();
|
|
|
|
auto e = std::min(arrayAttr.size(), args.size());
|
|
|
|
for (unsigned i = 0; i < e; ++i) {
|
2020-02-14 22:54:18 -08:00
|
|
|
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
|
2019-12-19 22:15:31 -08:00
|
|
|
setNameFn(args[i], strAttr.getValue());
|
|
|
|
}
|
|
|
|
}
|
2019-11-20 10:19:01 -08:00
|
|
|
};
|
|
|
|
|
2019-09-01 20:06:42 -07:00
|
|
|
struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
|
|
|
|
using OpFolderDialectInterface::OpFolderDialectInterface;
|
|
|
|
|
|
|
|
/// Registered hook to check if the given region, which is attached to an
|
|
|
|
/// operation that is *not* isolated from above, should be used when
|
|
|
|
/// materializing constants.
|
2019-09-07 18:56:39 -07:00
|
|
|
bool shouldMaterializeInto(Region *region) const final {
|
2019-09-01 20:06:42 -07:00
|
|
|
// If this is a one region operation, then insert into it.
|
|
|
|
return isa<OneRegionOp>(region->getParentOp());
|
|
|
|
}
|
|
|
|
};
|
2019-09-05 12:23:45 -07:00
|
|
|
|
|
|
|
/// This class defines the interface for handling inlining with standard
|
|
|
|
/// operations.
|
|
|
|
struct TestInlinerInterface : public DialectInlinerInterface {
|
|
|
|
using DialectInlinerInterface::DialectInlinerInterface;
|
|
|
|
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
// Analysis Hooks
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
|
2019-10-03 23:04:56 -07:00
|
|
|
bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
|
|
|
|
// Inlining into test dialect regions is legal.
|
|
|
|
return true;
|
|
|
|
}
|
2019-09-05 12:23:45 -07:00
|
|
|
bool isLegalToInline(Operation *, Region *,
|
|
|
|
BlockAndValueMapping &) const final {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2019-10-03 23:10:25 -07:00
|
|
|
bool shouldAnalyzeRecursively(Operation *op) const final {
|
2019-09-05 12:23:45 -07:00
|
|
|
// Analyze recursively if this is not a functional region operation, it
|
|
|
|
// froms a separate functional scope.
|
|
|
|
return !isa<FunctionalRegionOp>(op);
|
|
|
|
}
|
|
|
|
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
// Transformation Hooks
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Handle the given inlined terminator by replacing it with a new operation
|
|
|
|
/// as necessary.
|
|
|
|
void handleTerminator(Operation *op,
|
2019-12-23 14:45:01 -08:00
|
|
|
ArrayRef<Value> valuesToRepl) const final {
|
2019-09-05 12:23:45 -07:00
|
|
|
// Only handle "test.return" here.
|
|
|
|
auto returnOp = dyn_cast<TestReturnOp>(op);
|
|
|
|
if (!returnOp)
|
|
|
|
return;
|
|
|
|
|
|
|
|
// Replace the values directly with the return operands.
|
|
|
|
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
|
|
|
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
2020-01-11 08:54:04 -08:00
|
|
|
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
2019-09-05 12:23:45 -07:00
|
|
|
}
|
2019-10-03 23:10:25 -07:00
|
|
|
|
|
|
|
/// Attempt to materialize a conversion for a type mismatch between a call
|
|
|
|
/// from this dialect, and a callable region. This method should generate an
|
|
|
|
/// operation that takes 'input' as the only operand, and produces a single
|
|
|
|
/// result of 'resultType'. If a conversion can not be generated, nullptr
|
|
|
|
/// should be returned.
|
2019-12-23 14:45:01 -08:00
|
|
|
Operation *materializeCallConversion(OpBuilder &builder, Value input,
|
2019-10-03 23:10:25 -07:00
|
|
|
Type resultType,
|
|
|
|
Location conversionLoc) const final {
|
|
|
|
// Only allow conversion for i16/i32 types.
|
2020-01-10 14:48:24 -05:00
|
|
|
if (!(resultType.isSignlessInteger(16) ||
|
|
|
|
resultType.isSignlessInteger(32)) ||
|
|
|
|
!(input.getType().isSignlessInteger(16) ||
|
|
|
|
input.getType().isSignlessInteger(32)))
|
2019-10-03 23:10:25 -07:00
|
|
|
return nullptr;
|
|
|
|
return builder.create<TestCastOp>(conversionLoc, resultType, input);
|
|
|
|
}
|
2019-09-05 12:23:45 -07:00
|
|
|
};
|
2019-09-01 20:06:42 -07:00
|
|
|
} // end anonymous namespace
|
|
|
|
|
2019-05-23 15:11:19 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TestDialect
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Build the test dialect: accept unregistered operations, register every op
/// from the GET_OP_LIST section of TestOps.cpp.inc, and attach the asm,
/// folder and inliner dialect interfaces.
TestDialect::TestDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // Operations in this dialect's namespace need not be registered.
  allowUnknownOperations();

  // Generated operation list.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();

  // Dialect interfaces, registered one at a time.
  addInterfaces<TestOpAsmInterface>();
  addInterfaces<TestOpFolderDialectInterface>();
  addInterfaces<TestInlinerInterface>();
}
|
|
|
|
|
2019-11-12 11:57:47 -08:00
|
|
|
/// Shared implementation of the test dialect's attribute verification hooks:
/// an attribute named "test.invalid_attr" is always rejected with an error on
/// `op`; every other attribute is accepted.
static LogicalResult verifyNoInvalidTestAttr(Operation *op,
                                             NamedAttribute namedAttr) {
  if (namedAttr.first == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Verify a dialect attribute attached directly to `op`.
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  return verifyNoInvalidTestAttr(op, namedAttr);
}

/// Verify a dialect attribute attached to argument `argIndex` of region
/// `regionIndex` of `op`. The indices are not needed by the check.
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  return verifyNoInvalidTestAttr(op, namedAttr);
}

/// Verify a dialect attribute attached to result `resultIndex` of region
/// `regionIndex` of `op`. The indices are not needed by the check.
LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  return verifyNoInvalidTestAttr(op, namedAttr);
}
|
|
|
|
|
2020-03-05 12:40:23 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TestBranchOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
|
|
|
|
assert(index == 0 && "invalid successor index");
|
|
|
|
return getOperands();
|
|
|
|
}
|
|
|
|
|
|
|
|
bool TestBranchOp::canEraseSuccessorOperand() { return true; }
|
|
|
|
|
2019-08-19 15:26:43 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Test IsolatedRegionOp - parse passthrough region arguments.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parse `test.isolated_region %operand { ... }`. The single index-typed
/// operand is also used as the (shadowed) entry block argument of the body.
static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  OpAsmParser::OperandType operandInfo;
  Type indexTy = parser.getBuilder().getIndexType();

  // The input operand.
  if (failed(parser.parseOperand(operandInfo)))
    return failure();
  if (failed(parser.resolveOperand(operandInfo, indexTy, result.operands)))
    return failure();

  // The body, whose entry argument reuses the operand's name and type.
  Region &body = *result.addRegion();
  return parser.parseRegion(body, operandInfo, indexTy,
                            /*enableNameShadowing=*/true);
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Print an IsolatedRegionOp as `test.isolated_region %operand { ... }`,
/// shadowing the region's entry argument with the operand so the argument
/// list itself is elided.
static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
  Value input = op.getOperand();
  Region &body = op.region();
  p << "test.isolated_region ";
  p.printOperand(input);
  p.shadowRegionArgs(body, input);
  p.printRegion(body, /*printEntryBlockArgs=*/false);
}
|
|
|
|
|
2020-04-29 05:38:23 +05:30
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Test PolyhedralScopeOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parse `test.polyhedral_scope { ... }`: the op has no operands and only a
/// body region.
static ParseResult parsePolyhedralScopeOp(OpAsmParser &parser,
                                          OperationState &result) {
  // Parse the body region. There is no operand to reuse as a region argument
  // (unlike IsolatedRegionOp), so the region is parsed with empty argument
  // and argument-type lists.
  Region *body = result.addRegion();
  return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
}
|
|
|
|
|
|
|
|
/// Print a PolyhedralScopeOp as its name followed by its body region, without
/// the entry block's argument list.
static void print(OpAsmPrinter &p, PolyhedralScopeOp op) {
  Region &body = op.region();
  p << "test.polyhedral_scope ";
  p.printRegion(body, /*printEntryBlockArgs=*/false);
}
|
|
|
|
|
2019-09-08 23:39:34 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-09-17 17:54:54 -07:00
|
|
|
// Test WrappedKeywordOp - parse a bare keyword into a string attribute.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parse a single bare keyword and record it in the "keyword" string
/// attribute of the operation being built.
static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
                                         OperationState &result) {
  StringRef kw;
  if (failed(parser.parseKeyword(&kw)))
    return failure();

  Builder &b = parser.getBuilder();
  result.addAttribute("keyword", b.getStringAttr(kw));
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Print a WrappedKeywordOp as `<op-name> <keyword>`, the inverse of
/// parseWrappedKeywordOp.
static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
  StringRef opName = WrappedKeywordOp::getOperationName();
  p << opName << " " << op.keyword();
}
|
|
|
|
|
2019-09-08 23:39:34 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-09-17 17:54:54 -07:00
|
|
|
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
|
2019-09-08 23:39:34 -07:00
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parse `test.wrapping_region wraps <generic-op>`. The wrapped op is placed
/// in a new single-block region, followed by a `test.return` that forwards
/// its results. The wrapping op takes the wrapped op's result types and
/// location.
static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  if (parser.parseKeyword("wraps"))
    return failure();

  // Parse the wrapped op in a region. The region takes ownership of the newly
  // allocated block.
  Region &body = *result.addRegion();
  body.push_back(new Block);
  Block &block = body.back();
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
  if (!wrappedOp)
    return failure();

  // Create a return terminator in the inner region, passing as operands the
  // values returned by the wrapped operation.
  SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);

  // The wrapping op yields exactly the values forwarded by the terminator, so
  // its result types are the types of those values. They are read directly
  // here rather than by re-fetching the terminator from the block.
  for (Value operand : returnOperands)
    result.types.push_back(operand.getType());

  // Use the location of the wrapped op for the "test.wrapping_region" op.
  result.location = wrappedOp->getLoc();

  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Print a WrappingRegionOp as `<op-name> wraps <generic-op>`. Only the first
/// operation of the region's first block is printed, in generic form.
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
  Operation *wrapped = &op.region().front().front();
  p << op.getOperationName() << " wraps ";
  p.printGenericOp(wrapped);
}
|
|
|
|
|
2019-07-22 17:41:38 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Test PolyForOp - parse list of region arguments.
|
|
|
|
//===----------------------------------------------------------------------===//
|
2019-11-20 10:19:01 -08:00
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Parse an undelimited list of region arguments followed by the body region.
/// Every region argument is given index type.
static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> inductionVars;
  if (failed(parser.parseRegionArgumentList(inductionVars)))
    return failure();

  Region &body = *result.addRegion();
  Type indexType = parser.getBuilder().getIndexType();
  SmallVector<Type, 4> ivTypes(inductionVars.size(), indexType);
  return parser.parseRegion(body, inductionVars, ivTypes);
}
|
|
|
|
|
2019-08-06 11:08:22 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Test removing op with inner ops.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
2019-08-26 09:44:09 -07:00
|
|
|
struct TestRemoveOpWithInnerOps
|
|
|
|
: public OpRewritePattern<TestOpWithRegionPattern> {
|
|
|
|
using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
|
2019-08-06 11:08:22 -07:00
|
|
|
|
2020-03-17 20:07:55 -07:00
|
|
|
LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2019-10-16 09:50:28 -07:00
|
|
|
rewriter.eraseOp(op);
|
2020-03-17 20:07:55 -07:00
|
|
|
return success();
|
2019-08-06 11:08:22 -07:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
2019-08-26 09:44:09 -07:00
|
|
|
void TestOpWithRegionPattern::getCanonicalizationPatterns(
|
2019-08-06 11:08:22 -07:00
|
|
|
OwningRewritePatternList &results, MLIRContext *context) {
|
|
|
|
results.insert<TestRemoveOpWithInnerOps>(context);
|
|
|
|
}
|
|
|
|
|
2019-08-26 09:44:09 -07:00
|
|
|
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
|
|
|
|
return operand();
|
|
|
|
}
|
|
|
|
|
2019-10-09 20:42:32 -07:00
|
|
|
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
|
|
|
|
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
|
2019-12-23 14:45:01 -08:00
|
|
|
for (Value input : this->operands()) {
|
2019-10-09 20:42:32 -07:00
|
|
|
results.push_back(input);
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-12-06 14:42:16 -08:00
|
|
|
/// Infer a single result whose type equals the (identical) type of the first
/// two operands. Emits an error at `location`, if provided, when there are
/// fewer than two operands or when the two operand types differ.
LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Nothing here guarantees the operand count has already been verified, so
  // guard the indexing below instead of reading out of bounds.
  if (operands.size() < 2)
    return emitOptionalError(location, "expected two operands");
  if (operands[0].getType() != operands[1].getType()) {
    return emitOptionalError(location, "operand type mismatch ",
                             operands[0].getType(), " vs ",
                             operands[1].getType());
  }
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}
|
|
|
|
|
2020-01-08 18:48:38 -08:00
|
|
|
/// Infer a single 1-D shaped result with i17 elements. Its one dimension is
/// the leading dimension of the first operand, or dynamic when that operand is
/// unranked. Errors are emitted at `location`, if provided.
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // The first operand is dereferenced below, so it must exist.
  if (operands.empty())
    return emitOptionalError(location, "expected at least one operand");
  auto operandType = *operands.getTypes().begin();
  auto sval = operandType.dyn_cast<ShapedType>();
  if (!sval) {
    return emitOptionalError(location, "only shaped type operands allowed");
  }
  // A ranked 0-d operand has no leading dimension; calling front() on its
  // empty shape would be out of bounds.
  if (sval.hasRank() && sval.getRank() == 0)
    return emitOptionalError(location, "expected operand of non-zero rank");
  int64_t dim =
      sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
  auto type = IntegerType::get(17, context);
  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
  return success();
}
|
|
|
|
|
|
|
|
/// Materialize the result shape as one value: dimension 0 of operand #0,
/// built with `createOrFold<DimOp>`. `shapes` is overwritten with that value.
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
    OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
  Value leadingDim =
      builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0);
  shapes.clear();
  shapes.push_back(leadingDim);
  return success();
}
|
|
|
|
|
[mlir][SideEffects] Define a set of interfaces and traits for defining side effects
This revision introduces the infrastructure for defining side-effects and attaching them to operations. This infrastructure allows for defining different types of side effects, that don't interact with each other, but use the same internal mechanisms. At the base of this is an interface that allows operations to specify the different effect instances that are exhibited by a specific operation instance. An effect instance is comprised of the following:
* Effect: The specific effect being applied.
For memory related effects this may be reading from memory, storing to memory, etc.
* Value: A specific value, either operand/result/region argument, the effect pertains to.
* Resource: This is a global entity that represents the domain within which the effect is being applied.
MLIR serves many different abstractions, which cover many different domains. Simple effects may have very different contexts, for example writing to an in-memory buffer vs. a database. This revision uses this infrastructure to define an initial set of MemoryEffects. These are effects that generally correspond to memory of some kind: Allocate, Free, Read, Write.
This set of memory effects will be used in follow-up revisions to generalize various parts of the compiler, and make others more powerful (e.g. DCE).
This infrastructure was originally proposed here:
https://groups.google.com/a/tensorflow.org/g/mlir/c/v2mNl4vFCUM
Differential Revision: https://reviews.llvm.org/D74439
2020-03-06 13:53:16 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Test SideEffect interfaces
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// A test resource for side effects.
|
|
|
|
struct TestResource : public SideEffects::Resource::Base<TestResource> {
|
|
|
|
StringRef getName() final { return "<Test>"; }
|
|
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
|
|
/// Report the memory effects described by the op's optional "effects" array
/// attribute. Each entry is a dictionary with:
///   - "effect": one of "allocate", "free", "read", "write";
///   - "on_result" (optional, presence only): the effect applies to the
///     op's result;
///   - "test_resource" (optional, presence only): use TestResource instead of
///     the default resource.
void SideEffectOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Check for an effects attribute on the op instance.
  ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
  if (!effectsAttr)
    return;

  // If there is one, it is an array of dictionary attributes that hold
  // information on the effects of this operation.
  for (Attribute element : effectsAttr) {
    auto effectElement = element.dyn_cast<DictionaryAttr>();
    StringAttr effectName;
    if (effectElement)
      effectName = effectElement.get("effect").dyn_cast_or_null<StringAttr>();

    // Get the specific memory effect. Unknown or missing kinds map to null
    // instead of falling off the end of the StringSwitch, which is undefined
    // behavior when assertions are disabled.
    MemoryEffects::Effect *effect =
        llvm::StringSwitch<MemoryEffects::Effect *>(
            effectName ? effectName.getValue() : StringRef())
            .Case("allocate", MemoryEffects::Allocate::get())
            .Case("free", MemoryEffects::Free::get())
            .Case("read", MemoryEffects::Read::get())
            .Case("write", MemoryEffects::Write::get())
            .Default(nullptr);
    assert(effect && "expected a dictionary with a known 'effect' string");
    if (!effect)
      continue;

    // Check for a result to affect.
    Value value;
    if (effectElement.get("on_result"))
      value = getResult();

    // Check for a non-default resource to use.
    SideEffects::Resource *resource = SideEffects::DefaultResource::get();
    if (effectElement.get("test_resource"))
      resource = TestResource::get();

    effects.emplace_back(effect, value, resource);
  }
}
|
|
|
|
|
Add support for custom op parser/printer hooks to know about result names.
Summary:
This allows the custom parser/printer hooks to do interesting things with
the SSA names. This patch:
- Adds a new 'getResultName' method to OpAsmParser that allows a parser
implementation to get information about its result names, along with
a getNumResults() method that allows op parser impls to know how many
results are expected.
- Adds an OpAsmPrinter::printOperand overload that takes an explicit stream.
- Adds a test.string_attr_pretty_name operation that uses these hooks to
do fancy things with the result name.
Reviewers: rriddle!
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76205
2020-03-15 17:13:59 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// StringAttrPrettyNameOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// This op has fancy handling of its SSA result name.
|
|
|
|
static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
|
|
|
|
OperationState &result) {
|
|
|
|
// Add the result types.
|
|
|
|
for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
|
|
|
|
result.addTypes(parser.getBuilder().getIntegerType(32));
|
|
|
|
|
|
|
|
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
// If the attribute dictionary contains no 'names' attribute, infer it from
|
|
|
|
// the SSA name (if specified).
|
|
|
|
bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
|
Eliminate all uses of Identifier::is() in the source tree, this doesn't remove the definition of it (yet). NFC.
Reviewers: mravishankar, antiagainst, herhut, rriddle!
Subscribers: jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, csigg, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, bader, grosul1, frgossen, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D78042
2020-04-13 11:17:35 -07:00
|
|
|
return attr.first == "names";
|
Add support for custom op parser/printer hooks to know about result names.
Summary:
This allows the custom parser/printer hooks to do interesting things with
the SSA names. This patch:
- Adds a new 'getResultName' method to OpAsmParser that allows a parser
implementation to get information about its result names, along with
a getNumResults() method that allows op parser impls to know how many
results are expected.
- Adds a OpAsmPrinter::printOperand overload that takes an explicit stream.
- Adds a test.string_attr_pretty_name operation that uses these hooks to
do fancy things with the result name.
Reviewers: rriddle!
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76205
2020-03-15 17:13:59 -07:00
|
|
|
});
|
|
|
|
|
|
|
|
// If there was no name specified, check to see if there was a useful name
|
|
|
|
// specified in the asm file.
|
|
|
|
if (hadNames || parser.getNumResults() == 0)
|
|
|
|
return success();
|
|
|
|
|
|
|
|
SmallVector<StringRef, 4> names;
|
|
|
|
auto *context = result.getContext();
|
|
|
|
|
|
|
|
for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
|
|
|
|
auto resultName = parser.getResultName(i);
|
|
|
|
StringRef nameStr;
|
|
|
|
if (!resultName.first.empty() && !isdigit(resultName.first[0]))
|
|
|
|
nameStr = resultName.first;
|
|
|
|
|
|
|
|
names.push_back(nameStr);
|
|
|
|
}
|
|
|
|
|
|
|
|
auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
|
|
|
|
result.attributes.push_back({Identifier::get("names", context), namesAttr});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Print a StringAttrPrettyNameOp. The 'names' attribute is printed only when
/// the names the AsmPrinter actually chose for the results disagree with it.
/// This can happen in unusual cases, e.g. naming conflicts. Otherwise the
/// attribute is elided, because the parser recovers it from the SSA names.
static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
  p << "test.string_attr_pretty_name";

  auto names = op.names();
  bool printNames = names.size() != op.getNumResults();

  SmallString<32> printedName;
  for (size_t idx = 0, numResults = op.getNumResults();
       !printNames && idx != numResults; ++idx) {
    printedName.clear();
    llvm::raw_svector_ostream os(printedName);
    p.printOperand(op.getResult(idx), os);

    // Compare the printed name, minus its leading sigil, with the attribute.
    StringAttr expected = names[idx].dyn_cast<StringAttr>();
    printNames = !expected || os.str().drop_front() != expected.getValue();
  }

  SmallVector<StringRef, 1> elided;
  if (!printNames)
    elided.push_back("names");
  p.printOptionalAttrDictWithKeyword(op.getAttrs(), elided);
}
|
|
|
|
|
|
|
|
// We set the SSA name in the asm syntax to the contents of the name
|
|
|
|
// attribute.
|
|
|
|
void StringAttrPrettyNameOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
|
|
|
|
|
|
auto value = names();
|
|
|
|
for (size_t i = 0, e = value.size(); i != e; ++i)
|
|
|
|
if (auto str = value[i].dyn_cast<StringAttr>())
|
|
|
|
if (!str.getValue().empty())
|
|
|
|
setNameFn(getResult(i), str.getValue());
|
|
|
|
}
|
|
|
|
|
[mlir][SideEffects] Define a set of interfaces and traits for defining side effects
This revision introduces the infrastructure for defining side-effects and attaching them to operations. This infrastructure allows for defining different types of side effects, that don't interact with each other, but use the same internal mechanisms. At the base of this is an interface that allows operations to specify the different effect instances that are exhibited by a specific operation instance. An effect instance is comprised of the following:
* Effect: The specific effect being applied.
For memory related effects this may be reading from memory, storing to memory, etc.
* Value: A specific value, either operand/result/region argument, the effect pertains to.
* Resource: This is a global entity that represents the domain within which the effect is being applied.
MLIR serves many different abstractions, which cover many different domains. Simple effects may have very different contexts, for example writing to an in-memory buffer vs. a database. This revision uses this infrastructure to define an initial set of MemoryEffects. These are effects that generally correspond to memory of some kind: Allocate, Free, Read, Write.
This set of memory effects will be used in follow-up revisions to generalize various parts of the compiler, and make others more powerful (e.g. DCE).
This infrastructure was originally proposed here:
https://groups.google.com/a/tensorflow.org/g/mlir/c/v2mNl4vFCUM
Differential Revision: https://reviews.llvm.org/D74439
2020-03-06 13:53:16 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Dialect Registration
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-05-23 15:11:19 -07:00
|
|
|
// Static initialization for Test dialect registration.
|
|
|
|
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
|
|
|
|
|
2019-11-25 11:29:21 -08:00
|
|
|
#include "TestOpEnums.cpp.inc"
|
2020-04-27 17:52:59 -07:00
|
|
|
#include "TestOpStructs.cpp.inc"
|
2019-11-25 11:29:21 -08:00
|
|
|
|
2019-05-23 15:11:19 -07:00
|
|
|
#define GET_OP_CLASSES
|
|
|
|
#include "TestOps.cpp.inc"
|