2018-07-05 09:12:11 -07:00
|
|
|
//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
2018-10-10 14:23:30 -07:00
|
|
|
#include "mlir/StandardOps/StandardOps.h"
|
2018-09-24 10:23:02 -07:00
|
|
|
#include "mlir/IR/AffineExpr.h"
|
2018-07-24 10:13:31 -07:00
|
|
|
#include "mlir/IR/AffineMap.h"
|
2018-07-25 11:15:20 -07:00
|
|
|
#include "mlir/IR/Builders.h"
|
2018-10-10 14:23:30 -07:00
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
2018-10-29 10:22:49 -07:00
|
|
|
#include "mlir/IR/Matchers.h"
|
2018-07-24 16:07:22 -07:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2018-10-25 16:44:04 -07:00
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2019-01-03 14:29:52 -08:00
|
|
|
#include "mlir/IR/StandardTypes.h"
|
2018-12-27 14:35:10 -08:00
|
|
|
#include "mlir/IR/Value.h"
|
2018-10-03 10:07:54 -07:00
|
|
|
#include "mlir/Support/MathExtras.h"
|
2018-08-09 12:28:58 -07:00
|
|
|
#include "mlir/Support/STLExtras.h"
|
2018-11-08 04:02:00 -08:00
|
|
|
#include "llvm/ADT/StringSwitch.h"
|
2018-07-05 09:12:11 -07:00
|
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
|
2018-10-21 19:49:31 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// StandardOpsDialect
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Constructs the standard dialect, registered with an empty name prefix, and
/// registers all of its operations: the hand-written ones listed explicitly
/// here, followed by the generated list included from standard_ops.inc.
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
    : Dialect(/*namePrefix=*/"", context) {
  addOperations<AllocOp,
                CallOp,
                CallIndirectOp,
                CmpIOp,
                DeallocOp,
                DimOp,
                DmaStartOp,
                DmaWaitOp,
                ExtractElementOp,
                LoadOp,
                MemRefCastOp,
                SelectOp,
                StoreOp,
                TensorCastOp,
#define GET_OP_LIST
#include "mlir/StandardOps/standard_ops.inc"
                >();
}
|
|
|
|
|
2018-10-25 16:44:04 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Common canonicalization pattern support logic
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// This is a common class used for patterns of the form
|
|
|
|
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
|
|
|
|
/// into the root operation directly.
|
2018-11-28 15:09:39 -08:00
|
|
|
struct MemRefCastFolder : public RewritePattern {
|
2018-10-25 16:44:04 -07:00
|
|
|
/// The rootOpName is the name of the root operation to match against.
|
|
|
|
MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
|
2018-11-28 15:09:39 -08:00
|
|
|
: RewritePattern(rootOpName, 1, context) {}
|
2018-10-25 16:44:04 -07:00
|
|
|
|
2018-12-27 21:21:41 -08:00
|
|
|
PatternMatchResult match(OperationInst *op) const override {
|
2018-10-25 16:44:04 -07:00
|
|
|
for (auto *operand : op->getOperands())
|
2018-10-30 10:57:50 -07:00
|
|
|
if (matchPattern(operand, m_Op<MemRefCastOp>()))
|
2018-10-29 10:22:49 -07:00
|
|
|
return matchSuccess();
|
2018-10-25 16:44:04 -07:00
|
|
|
|
|
|
|
return matchFailure();
|
|
|
|
}
|
|
|
|
|
2018-12-27 21:21:41 -08:00
|
|
|
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
|
2018-10-25 16:44:04 -07:00
|
|
|
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
|
2018-12-27 21:21:41 -08:00
|
|
|
if (auto *memref = op->getOperand(i)->getDefiningInst())
|
2018-10-25 16:44:04 -07:00
|
|
|
if (auto cast = memref->dyn_cast<MemRefCastOp>())
|
|
|
|
op->setOperand(i, cast->getOperand());
|
|
|
|
rewriter.updatedRootInPlace(op);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end anonymous namespace.
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AddFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Folds `addf` when both operands are float constants of the same type.
/// Returns a null attribute when the operation cannot be folded.
Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
  assert(operands.size() == 2 && "addf takes two operands");

  auto lhsAttr = operands[0].dyn_cast_or_null<FloatAttr>();
  auto rhsAttr = operands[1].dyn_cast_or_null<FloatAttr>();
  if (!lhsAttr || !rhsAttr || lhsAttr.getType() != rhsAttr.getType())
    return nullptr;

  return FloatAttr::get(lhsAttr.getType(),
                        lhsAttr.getValue() + rhsAttr.getValue());
}
|
|
|
|
|
2018-10-03 09:43:13 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AddIOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Folds `addi` when both operands are integer constants of the same type.
/// Returns a null attribute when the operation cannot be folded.
Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
  assert(operands.size() == 2 && "addi takes two operands");

  auto lhsAttr = operands[0].dyn_cast_or_null<IntegerAttr>();
  auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!lhsAttr || !rhsAttr || lhsAttr.getType() != rhsAttr.getType())
    return nullptr;

  return IntegerAttr::get(lhsAttr.getType(),
                          lhsAttr.getValue() + rhsAttr.getValue());
}
|
|
|
|
|
2018-10-25 16:44:04 -07:00
|
|
|
namespace {
|
|
|
|
/// addi(x, 0) -> x
|
|
|
|
///
|
2018-11-28 15:09:39 -08:00
|
|
|
struct SimplifyAddX0 : public RewritePattern {
|
2018-10-25 16:44:04 -07:00
|
|
|
SimplifyAddX0(MLIRContext *context)
|
2018-11-28 15:09:39 -08:00
|
|
|
: RewritePattern(AddIOp::getOperationName(), 1, context) {}
|
2018-10-25 16:44:04 -07:00
|
|
|
|
2018-12-27 21:21:41 -08:00
|
|
|
PatternMatchResult match(OperationInst *op) const override {
|
2018-10-25 16:44:04 -07:00
|
|
|
auto addi = op->cast<AddIOp>();
|
2018-10-29 10:22:49 -07:00
|
|
|
|
2018-10-30 10:57:50 -07:00
|
|
|
if (matchPattern(addi->getOperand(1), m_Zero()))
|
2018-10-29 10:22:49 -07:00
|
|
|
return matchSuccess();
|
2018-10-25 16:44:04 -07:00
|
|
|
|
|
|
|
return matchFailure();
|
|
|
|
}
|
2018-12-27 21:21:41 -08:00
|
|
|
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
|
2018-11-24 07:40:55 -08:00
|
|
|
rewriter.replaceOp(op, op->getOperand(0));
|
2018-10-25 16:44:04 -07:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end anonymous namespace.
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
results.push_back(std::make_unique<SimplifyAddX0>(context));
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AllocOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-31 14:49:38 -07:00
|
|
|
void AllocOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
MemRefType memrefType, ArrayRef<Value *> operands) {
|
2018-08-31 14:49:38 -07:00
|
|
|
result->addOperands(operands);
|
|
|
|
result->types.push_back(memrefType);
|
|
|
|
}
|
|
|
|
|
2018-07-30 13:08:05 -07:00
|
|
|
/// Prints `alloc`, then the dimension/symbol operand lists, then the attribute
/// dictionary (never including "map"), and finally ` : ` and the memref type.
void AllocOp::print(OpAsmPrinter *p) const {
  MemRefType memrefType = getType();
  *p << "alloc";
  // The first getNumDynamicDims() operands are the dimension operands; the
  // remaining ones are printed as symbols.
  printDimAndSymbolList(operand_begin(), operand_end(),
                        memrefType.getNumDynamicDims(), p);
  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
  *p << " : " << memrefType;
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses the dimension operands, optional symbol operands, an optional
/// attribute dictionary and `: memref-type`. Returns true on failure.
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
  MemRefType memrefType;
  unsigned numDimOperands;

  bool parseFailed =
      parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(memrefType);
  if (parseFailed)
    return true;

  // The number of dimension operands must equal the number of '?' entries in
  // the memref type. This check lives here rather than in verify() because
  // the split between dimension and symbol operands is lost after parsing;
  // verify() still checks that the total operand count equals the dynamic
  // dimension count plus the number of affine map symbols.
  if (numDimOperands != memrefType.getNumDynamicDims())
    return parser->emitError(parser->getNameLoc(),
                             "dimension operand count does not equal memref "
                             "dynamic dimension count");

  result->types.push_back(memrefType);
  return false;
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool AllocOp::verify() const {
|
2018-10-30 14:59:22 -07:00
|
|
|
auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
|
2018-09-20 09:39:55 -07:00
|
|
|
if (!memRefType)
|
|
|
|
return emitOpError("result must be a memref");
|
|
|
|
|
|
|
|
unsigned numSymbols = 0;
|
2018-10-30 14:59:22 -07:00
|
|
|
if (!memRefType.getAffineMaps().empty()) {
|
|
|
|
AffineMap affineMap = memRefType.getAffineMaps()[0];
|
2018-09-20 09:39:55 -07:00
|
|
|
// Store number of symbols used in affine map (used in subsequent check).
|
2018-10-09 16:39:24 -07:00
|
|
|
numSymbols = affineMap.getNumSymbols();
|
2018-10-26 06:15:38 -07:00
|
|
|
// TODO(zinenko): this check does not belong to AllocOp, or any other op but
|
|
|
|
// to the type system itself. It has been partially hoisted to Parser but
|
|
|
|
// remains here in case an AllocOp gets constructed programmatically.
|
|
|
|
// Remove when we can emit errors directly from *Type::get(...) functions.
|
|
|
|
//
|
2018-09-20 09:39:55 -07:00
|
|
|
// Verify that the layout affine map matches the rank of the memref.
|
2018-10-30 14:59:22 -07:00
|
|
|
if (affineMap.getNumDims() != memRefType.getRank())
|
2018-09-20 09:39:55 -07:00
|
|
|
return emitOpError("affine map dimension count must equal memref rank");
|
|
|
|
}
|
2018-10-30 14:59:22 -07:00
|
|
|
unsigned numDynamicDims = memRefType.getNumDynamicDims();
|
2018-09-20 09:39:55 -07:00
|
|
|
// Check that the total number of operands matches the number of symbols in
|
|
|
|
// the affine map, plus the number of dynamic dimensions specified in the
|
|
|
|
// memref type.
|
2018-12-28 04:14:52 -08:00
|
|
|
if (getInstruction()->getNumOperands() != numDynamicDims + numSymbols) {
|
2018-09-20 09:39:55 -07:00
|
|
|
return emitOpError(
|
|
|
|
"operand count does not equal dimension plus symbol operand count");
|
|
|
|
}
|
2018-10-06 17:21:53 -07:00
|
|
|
// Verify that all operands are of type Index.
|
2018-09-20 09:39:55 -07:00
|
|
|
for (auto *operand : getOperands()) {
|
2018-10-30 14:59:22 -07:00
|
|
|
if (!operand->getType().isIndex())
|
2018-10-06 17:21:53 -07:00
|
|
|
return emitOpError("requires operands to be of type Index");
|
2018-09-20 09:39:55 -07:00
|
|
|
}
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-30 13:08:05 -07:00
|
|
|
}
|
|
|
|
|
2018-10-25 16:44:04 -07:00
|
|
|
namespace {
|
|
|
|
/// Fold constant dimensions into an alloc instruction.
|
2018-11-28 15:09:39 -08:00
|
|
|
struct SimplifyAllocConst : public RewritePattern {
|
2018-10-25 16:44:04 -07:00
|
|
|
SimplifyAllocConst(MLIRContext *context)
|
2018-11-28 15:09:39 -08:00
|
|
|
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
|
2018-10-25 16:44:04 -07:00
|
|
|
|
2018-12-27 21:21:41 -08:00
|
|
|
PatternMatchResult match(OperationInst *op) const override {
|
2018-10-25 16:44:04 -07:00
|
|
|
auto alloc = op->cast<AllocOp>();
|
|
|
|
|
|
|
|
// Check to see if any dimensions operands are constants. If so, we can
|
|
|
|
// substitute and drop them.
|
|
|
|
for (auto *operand : alloc->getOperands())
|
2018-10-30 10:57:50 -07:00
|
|
|
if (matchPattern(operand, m_ConstantIndex()))
|
2018-10-29 10:22:49 -07:00
|
|
|
return matchSuccess();
|
2018-10-25 16:44:04 -07:00
|
|
|
return matchFailure();
|
|
|
|
}
|
|
|
|
|
2018-12-27 21:21:41 -08:00
|
|
|
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
|
2018-10-25 16:44:04 -07:00
|
|
|
auto allocOp = op->cast<AllocOp>();
|
|
|
|
auto memrefType = allocOp->getType();
|
|
|
|
|
|
|
|
// Ok, we have one or more constant operands. Collect the non-constant ones
|
|
|
|
// and keep track of the resultant memref type to build.
|
|
|
|
SmallVector<int, 4> newShapeConstants;
|
2018-10-30 14:59:22 -07:00
|
|
|
newShapeConstants.reserve(memrefType.getRank());
|
2018-12-27 14:35:10 -08:00
|
|
|
SmallVector<Value *, 4> newOperands;
|
|
|
|
SmallVector<Value *, 4> droppedOperands;
|
2018-10-25 16:44:04 -07:00
|
|
|
|
|
|
|
unsigned dynamicDimPos = 0;
|
2018-10-30 14:59:22 -07:00
|
|
|
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
|
|
|
|
int dimSize = memrefType.getDimSize(dim);
|
2018-10-25 16:44:04 -07:00
|
|
|
// If this is already static dimension, keep it.
|
|
|
|
if (dimSize != -1) {
|
|
|
|
newShapeConstants.push_back(dimSize);
|
|
|
|
continue;
|
|
|
|
}
|
2018-12-27 21:21:41 -08:00
|
|
|
auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst();
|
2018-10-25 16:44:04 -07:00
|
|
|
OpPointer<ConstantIndexOp> constantIndexOp;
|
|
|
|
if (defOp && (constantIndexOp = defOp->dyn_cast<ConstantIndexOp>())) {
|
|
|
|
// Dynamic shape dimension will be folded.
|
|
|
|
newShapeConstants.push_back(constantIndexOp->getValue());
|
|
|
|
// Record to check for zero uses later below.
|
|
|
|
droppedOperands.push_back(constantIndexOp);
|
|
|
|
} else {
|
|
|
|
// Dynamic shape dimension not folded; copy operand from old memref.
|
|
|
|
newShapeConstants.push_back(-1);
|
|
|
|
newOperands.push_back(allocOp->getOperand(dynamicDimPos));
|
|
|
|
}
|
|
|
|
dynamicDimPos++;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create new memref type (which will have fewer dynamic dimensions).
|
2018-10-30 14:59:22 -07:00
|
|
|
auto newMemRefType = MemRefType::get(
|
|
|
|
newShapeConstants, memrefType.getElementType(),
|
|
|
|
memrefType.getAffineMaps(), memrefType.getMemorySpace());
|
|
|
|
assert(newOperands.size() == newMemRefType.getNumDynamicDims());
|
2018-10-25 16:44:04 -07:00
|
|
|
|
|
|
|
// Create and insert the alloc op for the new memref.
|
|
|
|
auto newAlloc =
|
|
|
|
rewriter.create<AllocOp>(allocOp->getLoc(), newMemRefType, newOperands);
|
|
|
|
// Insert a cast so we have the same type as the old alloc.
|
|
|
|
auto resultCast = rewriter.create<MemRefCastOp>(allocOp->getLoc(), newAlloc,
|
|
|
|
allocOp->getType());
|
|
|
|
|
2018-11-24 07:40:55 -08:00
|
|
|
rewriter.replaceOp(op, {resultCast}, droppedOperands);
|
2018-10-25 16:44:04 -07:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end anonymous namespace.
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
results.push_back(std::make_unique<SimplifyAllocConst>(context));
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-08-21 17:55:22 -07:00
|
|
|
// CallOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-22 19:25:49 -07:00
|
|
|
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
|
2018-12-27 14:35:10 -08:00
|
|
|
ArrayRef<Value *> operands) {
|
2018-08-22 19:25:49 -07:00
|
|
|
result->addOperands(operands);
|
|
|
|
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
2018-10-30 14:59:22 -07:00
|
|
|
result->addTypes(callee->getType().getResults());
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses a function name, a parenthesized operand list, an optional
/// attribute dictionary and `: function-type`, then resolves the callee and
/// operands against that type. Returns true on failure.
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
  StringRef calleeName;
  llvm::SMLoc calleeLoc;
  FunctionType calleeType;
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Function *callee = nullptr;

  // Syntactic part.
  if (parser->parseFunctionName(calleeName, calleeLoc) ||
      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
                               OpAsmParser::Delimiter::Paren) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(calleeType))
    return true;

  // Resolution against the declared function type.
  if (parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
      parser->addTypesToList(calleeType.getResults(), result->types) ||
      parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
                              result->operands))
    return true;

  result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
  return false;
}
|
|
|
|
|
|
|
|
/// Prints `call`, the callee reference, the parenthesized operands, the
/// attribute dictionary without "callee", and ` : ` plus the callee's type.
void CallOp::print(OpAsmPrinter *p) const {
  auto *callee = getCallee();
  *p << "call ";
  p->printFunctionReference(callee);
  *p << '(';
  p->printOperands(getOperands());
  *p << ')';
  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
  *p << " : " << callee->getType();
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool CallOp::verify() const {
|
2018-08-21 17:55:22 -07:00
|
|
|
// Check that the callee attribute was specified.
|
2018-10-25 15:46:10 -07:00
|
|
|
auto fnAttr = getAttrOfType<FunctionAttr>("callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
if (!fnAttr)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires a 'callee' function attribute");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
// Verify that the operand and result types match the callee.
|
2018-10-30 14:59:22 -07:00
|
|
|
auto fnType = fnAttr.getValue()->getType();
|
|
|
|
if (fnType.getNumInputs() != getNumOperands())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of operands for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
|
|
|
if (getOperand(i)->getType() != fnType.getInput(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("operand type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (fnType.getNumResults() != getNumResults())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of results for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
|
|
|
if (getResult(i)->getType() != fnType.getResult(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// CallIndirectOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-22 19:25:49 -07:00
|
|
|
void CallIndirectOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *callee, ArrayRef<Value *> operands) {
|
2018-10-30 14:59:22 -07:00
|
|
|
auto fnType = callee->getType().cast<FunctionType>();
|
2018-08-22 19:25:49 -07:00
|
|
|
result->operands.push_back(callee);
|
|
|
|
result->addOperands(operands);
|
2018-10-30 14:59:22 -07:00
|
|
|
result->addTypes(fnType.getResults());
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses the callee operand, a parenthesized argument list, an optional
/// attribute dictionary and `: function-type`, then resolves everything
/// against that type. Returns true on failure.
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType callee;
  SmallVector<OpAsmParser::OperandType, 4> operands;
  llvm::SMLoc operandsLoc;
  FunctionType calleeType;

  if (parser->parseOperand(callee) ||
      parser->getCurrentLocation(&operandsLoc) ||
      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
                               OpAsmParser::Delimiter::Paren) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(calleeType))
    return true;

  // The callee is resolved first so that it ends up as operand #0.
  if (parser->resolveOperand(callee, calleeType, result->operands) ||
      parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
                              result->operands))
    return true;

  return parser->addTypesToList(calleeType.getResults(), result->types);
}
|
|
|
|
|
|
|
|
/// Prints `call_indirect`, the callee operand, the parenthesized call
/// arguments, the attribute dictionary and ` : ` plus the callee's type.
///
/// Unlike `call`, this op carries its callee as operand #0 rather than in a
/// "callee" attribute (see CallIndirectOp::build). No attribute is elided
/// here: eliding "callee" would silently drop a user attribute of that name
/// and break round-tripping through the printer and parser.
void CallIndirectOp::print(OpAsmPrinter *p) const {
  *p << "call_indirect ";
  p->printOperand(getCallee());
  *p << '(';
  // Skip operand #0 (the callee); print only the call arguments.
  auto operandRange = getOperands();
  p->printOperands(++operandRange.begin(), operandRange.end());
  *p << ')';
  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{});
  *p << " : " << getCallee()->getType();
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool CallIndirectOp::verify() const {
|
2018-08-21 17:55:22 -07:00
|
|
|
// The callee must be a function.
|
2018-10-30 14:59:22 -07:00
|
|
|
auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
|
2018-08-21 17:55:22 -07:00
|
|
|
if (!fnType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("callee must have function type");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
// Verify that the operand and result types match the callee.
|
2018-10-30 14:59:22 -07:00
|
|
|
if (fnType.getNumInputs() != getNumOperands() - 1)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of operands for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
|
|
|
if (getOperand(i + 1)->getType() != fnType.getInput(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("operand type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (fnType.getNumResults() != getNumResults())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of results for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
|
|
|
if (getResult(i)->getType() != fnType.getResult(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-11-08 04:02:00 -08:00
|
|
|
// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
|
|
|
static Type getI1SameShape(Builder *build, Type type) {
|
2018-11-28 11:49:26 -08:00
|
|
|
auto i1Type = build->getI1Type();
|
2018-12-05 04:31:59 -08:00
|
|
|
if (type.isIntOrIndexOrFloat())
|
2018-11-08 04:02:00 -08:00
|
|
|
return i1Type;
|
|
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
|
|
|
return build->getTensorType(tensorType.getShape(), i1Type);
|
|
|
|
if (auto tensorType = type.dyn_cast<UnrankedTensorType>())
|
|
|
|
return build->getTensorType(i1Type);
|
|
|
|
if (auto vectorType = type.dyn_cast<VectorType>())
|
|
|
|
return build->getVectorType(vectorType.getShape(), i1Type);
|
|
|
|
|
|
|
|
llvm_unreachable("unsupported type");
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns true if `type` is an integer type of width 1.
static inline bool isI1(Type type) {
  auto intType = type.dyn_cast<IntegerType>();
  return intType && intType.getWidth() == 1;
}
|
|
|
|
|
|
|
|
template <typename Ty>
|
|
|
|
static inline bool implCheckI1SameShape(Ty pattern, Type type) {
|
|
|
|
auto specificType = type.dyn_cast<Ty>();
|
|
|
|
if (!specificType)
|
|
|
|
return true;
|
|
|
|
if (specificType.getShape() != pattern.getShape())
|
|
|
|
return true;
|
|
|
|
return !isI1(specificType.getElementType());
|
|
|
|
}
|
|
|
|
|
|
|
|
// Checks if "type" has the same shape (scalar, vector or tensor) as "pattern"
|
|
|
|
// and contains i1.
|
|
|
|
static bool checkI1SameShape(Type pattern, Type type) {
|
2018-12-05 04:31:59 -08:00
|
|
|
if (pattern.isIntOrIndexOrFloat())
|
2018-11-08 04:02:00 -08:00
|
|
|
return !isI1(type);
|
|
|
|
if (auto patternTensorType = pattern.dyn_cast<TensorType>())
|
|
|
|
return implCheckI1SameShape(patternTensorType, type);
|
|
|
|
if (auto patternVectorType = pattern.dyn_cast<VectorType>())
|
|
|
|
return implCheckI1SameShape(patternVectorType, type);
|
|
|
|
|
|
|
|
llvm_unreachable("unsupported type");
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns an array of mnemonics for CmpIPredicates, indexed by values thereof.
|
|
|
|
static inline const char *const *getPredicateNames() {
|
|
|
|
static const char *predicateNames[(int)CmpIPredicate::NumPredicates]{
|
|
|
|
/*EQ*/ "eq",
|
|
|
|
/*NE*/ "ne",
|
|
|
|
/*SLT*/ "slt",
|
|
|
|
/*SLE*/ "sle",
|
|
|
|
/*SGT*/ "sgt",
|
|
|
|
/*SGE*/ "sge",
|
|
|
|
/*ULT*/ "ult",
|
|
|
|
/*ULE*/ "ule",
|
|
|
|
/*UGT*/ "ugt",
|
|
|
|
/*UGE*/ "uge"};
|
|
|
|
return predicateNames;
|
|
|
|
};
|
|
|
|
|
|
|
|
// Returns a value of the predicate corresponding to the given mnemonic.
|
|
|
|
// Returns NumPredicates (one-past-end) if there is no such mnemonic.
|
|
|
|
CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
|
|
|
|
return llvm::StringSwitch<CmpIPredicate>(name)
|
|
|
|
.Case("eq", CmpIPredicate::EQ)
|
|
|
|
.Case("ne", CmpIPredicate::NE)
|
|
|
|
.Case("slt", CmpIPredicate::SLT)
|
|
|
|
.Case("sle", CmpIPredicate::SLE)
|
|
|
|
.Case("sgt", CmpIPredicate::SGT)
|
|
|
|
.Case("sge", CmpIPredicate::SGE)
|
|
|
|
.Case("ult", CmpIPredicate::ULT)
|
|
|
|
.Case("ule", CmpIPredicate::ULE)
|
|
|
|
.Case("ugt", CmpIPredicate::UGT)
|
|
|
|
.Case("uge", CmpIPredicate::UGE)
|
|
|
|
.Default(CmpIPredicate::NumPredicates);
|
|
|
|
}
|
|
|
|
|
|
|
|
void CmpIOp::build(Builder *build, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
CmpIPredicate predicate, Value *lhs, Value *rhs) {
|
2018-11-08 04:02:00 -08:00
|
|
|
result->addOperands({lhs, rhs});
|
|
|
|
result->types.push_back(getI1SameShape(build, lhs->getType()));
|
|
|
|
result->addAttribute(getPredicateAttrName(),
|
2018-11-15 17:53:51 -08:00
|
|
|
build->getIntegerAttr(build->getIntegerType(64),
|
|
|
|
static_cast<int64_t>(predicate)));
|
2018-11-08 04:02:00 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Parses `cmpi "<predicate>", %lhs, %rhs {attrs} : type`. The string
// predicate is rewritten into its integer enum value, and the result type is
// the i1-shaped counterpart of the operand type. Returns true on failure.
bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> ops;
  SmallVector<NamedAttribute, 4> attrs;
  StringAttr predicateName;
  Type type;
  if (parser->parseAttribute(predicateName, getPredicateAttrName().data(),
                             attrs) ||
      parser->parseComma() || parser->parseOperandList(ops, 2) ||
      parser->parseOptionalAttributeDict(attrs) ||
      parser->parseColonType(type) ||
      parser->resolveOperands(ops, type, result->operands))
    return true;

  // Rewrite string attribute to an enum value.
  auto predicate = getPredicateByName(predicateName.getValue());
  if (predicate == CmpIPredicate::NumPredicates)
    return parser->emitError(parser->getNameLoc(),
                             "unknown comparison predicate \"" +
                                 Twine(predicateName.getValue()) + "\"");

  // getI1SameShape() only supports scalar, vector and tensor types and hits
  // llvm_unreachable for anything else (e.g. a memref). Reject such input
  // with a diagnostic instead of crashing on malformed source.
  if (!type.isIntOrIndexOrFloat() && !type.isa<VectorType>() &&
      !type.isa<RankedTensorType>() && !type.isa<UnrankedTensorType>())
    return parser->emitError(parser->getNameLoc(),
                             "expected operands of scalar, vector or tensor "
                             "type");

  auto builder = parser->getBuilder();
  // parseAttribute() above placed the predicate attribute first in `attrs`.
  attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
  result->attributes = attrs;

  result->addTypes({getI1SameShape(&builder, type)});
  return false;
}
|
|
|
|
|
|
|
|
// Prints `cmpi`, the predicate (as a string attribute holding its mnemonic),
// both operands, the remaining attributes, and ` : ` plus the operand type.
void CmpIOp::print(OpAsmPrinter *p) const {
  *p << getOperationName() << " ";

  // The predicate is stored as an integer; look up its mnemonic.
  auto predicateValue =
      getAttrOfType<IntegerAttr>(getPredicateAttrName()).getInt();
  assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
         predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
         "unknown predicate index");
  Builder builder(getInstruction()->getContext());
  auto mnemonicAttr =
      builder.getStringAttr(getPredicateNames()[predicateValue]);
  p->printAttribute(mnemonicAttr);

  *p << ", ";
  p->printOperand(getOperand(0));
  *p << ", ";
  p->printOperand(getOperand(1));
  p->printOptionalAttrDict(getAttrs(),
                           /*elidedAttrs=*/{getPredicateAttrName().data()});
  *p << " : " << getOperand(0)->getType();
}
|
|
|
|
|
|
|
|
bool CmpIOp::verify() const {
|
|
|
|
auto predicateAttr = getAttrOfType<IntegerAttr>(getPredicateAttrName());
|
|
|
|
if (!predicateAttr)
|
|
|
|
return emitOpError("requires an integer attribute named 'predicate'");
|
2018-11-12 06:33:22 -08:00
|
|
|
auto predicate = predicateAttr.getInt();
|
2018-11-08 04:02:00 -08:00
|
|
|
if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
|
|
|
|
predicate >= (int64_t)CmpIPredicate::NumPredicates)
|
|
|
|
return emitOpError("'predicate' attribute value out of range");
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2019-01-06 14:09:15 -08:00
|
|
|
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
|
|
|
// comparison predicates.
|
|
|
|
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
|
|
|
const APInt &rhs) {
|
|
|
|
switch (predicate) {
|
|
|
|
case CmpIPredicate::EQ:
|
|
|
|
return lhs.eq(rhs);
|
|
|
|
case CmpIPredicate::NE:
|
|
|
|
return lhs.ne(rhs);
|
|
|
|
case CmpIPredicate::SLT:
|
|
|
|
return lhs.slt(rhs);
|
|
|
|
case CmpIPredicate::SLE:
|
|
|
|
return lhs.sle(rhs);
|
|
|
|
case CmpIPredicate::SGT:
|
|
|
|
return lhs.sgt(rhs);
|
|
|
|
case CmpIPredicate::SGE:
|
|
|
|
return lhs.sge(rhs);
|
|
|
|
case CmpIPredicate::ULT:
|
|
|
|
return lhs.ult(rhs);
|
|
|
|
case CmpIPredicate::ULE:
|
|
|
|
return lhs.ule(rhs);
|
|
|
|
case CmpIPredicate::UGT:
|
|
|
|
return lhs.ugt(rhs);
|
|
|
|
case CmpIPredicate::UGE:
|
|
|
|
return lhs.uge(rhs);
|
|
|
|
default:
|
|
|
|
llvm_unreachable("unknown comparison predicate");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Constant folding hook for comparisons.
|
|
|
|
Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
|
|
|
|
MLIRContext *context) const {
|
|
|
|
assert(operands.size() == 2 && "cmpi takes two arguments");
|
|
|
|
|
|
|
|
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
|
|
|
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
|
|
|
|
if (!lhs || !rhs)
|
|
|
|
return {};
|
|
|
|
|
|
|
|
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
|
|
|
return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
|
|
|
|
}
|
|
|
|
|
2018-08-15 15:39:26 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DeallocOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-12-27 14:35:10 -08:00
|
|
|
/// Builds a `dealloc` of `memref`. The memref is the op's only operand, and the
/// op has no results.
void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) {
  // No attributes or result types to add; just record the operand.
  result->addOperands(memref);
}
|
|
|
|
|
2018-08-15 15:39:26 -07:00
|
|
|
/// Prints the custom form: `dealloc %memref : memref-type`.
void DeallocOp::print(OpAsmPrinter *p) const {
  auto *memref = getMemRef();
  *p << "dealloc " << *memref << " : " << memref->getType();
}
|
|
|
|
|
|
|
|
/// Parses `dealloc %memref : memref-type`. Returns true on failure.
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType operand;
  MemRefType memrefType;

  if (parser->parseOperand(operand) || parser->parseColonType(memrefType))
    return true;
  return parser->resolveOperand(operand, memrefType, result->operands);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool DeallocOp::verify() const {
|
2018-10-30 14:59:22 -07:00
|
|
|
if (!getMemRef()->getType().isa<MemRefType>())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("operand must be a memref");
|
|
|
|
return false;
|
2018-08-15 15:39:26 -07:00
|
|
|
}
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
/// dealloc(memrefcast) -> dealloc
|
|
|
|
results.push_back(
|
|
|
|
std::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DimOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-09-26 16:21:49 -07:00
|
|
|
void DimOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *memrefOrTensor, unsigned index) {
|
2018-09-26 16:21:49 -07:00
|
|
|
result->addOperands(memrefOrTensor);
|
2018-11-15 17:53:51 -08:00
|
|
|
auto type = builder->getIndexType();
|
|
|
|
result->addAttribute("index", builder->getIntegerAttr(type, index));
|
|
|
|
result->types.push_back(type);
|
2018-09-26 16:21:49 -07:00
|
|
|
}
|
|
|
|
|
2018-07-24 16:07:22 -07:00
|
|
|
/// Prints `dim %operand, <index> {attrs} : operand-type`. The "index" attribute
/// is printed inline, so it is elided from the attribute dictionary.
void DimOp::print(OpAsmPrinter *p) const {
  auto *source = getOperand();
  *p << "dim " << *source << ", " << getIndex();
  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
  *p << " : " << source->getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses `dim %operand, <index> {attrs} : operand-type`. The index literal is
/// stored as an `index`-typed "index" attribute, and the result type is
/// `index`. Returns true on failure.
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType source;
  IntegerAttr indexAttr;
  Type sourceType;
  Type indexTy = parser->getBuilder().getIndexType();

  // Operand and dimension number.
  if (parser->parseOperand(source) || parser->parseComma() ||
      parser->parseAttribute(indexAttr, indexTy, "index", result->attributes))
    return true;

  // Optional attributes, then the operand type.
  if (parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(sourceType) ||
      parser->resolveOperand(source, sourceType, result->operands))
    return true;

  return parser->addTypeToList(indexTy, result->types);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool DimOp::verify() const {
|
2018-07-06 10:46:19 -07:00
|
|
|
// Check that we have an integer index operand.
|
|
|
|
auto indexAttr = getAttrOfType<IntegerAttr>("index");
|
|
|
|
if (!indexAttr)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires an integer attribute named 'index'");
|
2018-11-12 06:33:22 -08:00
|
|
|
uint64_t index = indexAttr.getValue().getZExtValue();
|
2018-07-24 08:34:58 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
auto type = getOperand()->getType();
|
|
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
|
|
|
if (index >= tensorType.getRank())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("index is out of range");
|
2018-10-30 14:59:22 -07:00
|
|
|
} else if (auto memrefType = type.dyn_cast<MemRefType>()) {
|
|
|
|
if (index >= memrefType.getRank())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("index is out of range");
|
2018-07-06 10:46:19 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
} else if (type.isa<UnrankedTensorType>()) {
|
2018-07-24 08:34:58 -07:00
|
|
|
// ok, assumed to be in-range.
|
|
|
|
} else {
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires an operand with tensor or memref type");
|
2018-07-24 08:34:58 -07:00
|
|
|
}
|
2018-07-06 10:46:19 -07:00
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-06 10:46:19 -07:00
|
|
|
}
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Folds `dim` to an `index` constant when the queried dimension of a ranked
/// tensor or memref operand has a non-negative (static) size. Returns null
/// otherwise, e.g. for a dynamic dimension (-1) or an unranked operand.
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
                              MLIRContext *context) const {
  auto opType = getOperand()->getType();

  int dimSize = -1;
  if (auto tensorType = opType.dyn_cast<RankedTensorType>())
    dimSize = tensorType.getShape()[getIndex()];
  else if (auto memrefType = opType.dyn_cast<MemRefType>())
    dimSize = memrefType.getShape()[getIndex()];

  if (dimSize < 0)
    return nullptr;
  return IntegerAttr::get(Type::getIndex(context), dimSize);
}
|
|
|
|
|
2019-01-06 14:08:42 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DivISOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Folds signed integer division of two constants. The fold is skipped when the
/// divisor is zero or when `sdiv_ov` reports signed overflow.
Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "binary operation takes two operands");
  (void)context;

  auto dividend = operands.front().dyn_cast_or_null<IntegerAttr>();
  auto divisor = operands.back().dyn_cast_or_null<IntegerAttr>();
  // Need two constants and a non-zero divisor.
  if (!dividend || !divisor || divisor.getValue().isNullValue())
    return {};

  bool overflowed = false;
  APInt quotient =
      dividend.getValue().sdiv_ov(divisor.getValue(), overflowed);
  if (overflowed)
    return {};
  return IntegerAttr::get(dividend.getType(), quotient);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DivIUOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Folds unsigned integer division of two constants. The fold is skipped when
/// the divisor is zero.
Attribute DivIUOp::constantFold(ArrayRef<Attribute> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "binary operation takes two operands");
  (void)context;

  auto dividend = operands.front().dyn_cast_or_null<IntegerAttr>();
  auto divisor = operands.back().dyn_cast_or_null<IntegerAttr>();
  // Need two constants and a non-zero divisor.
  if (!dividend || !divisor || divisor.getValue().isNullValue())
    return {};

  APInt quotient = dividend.getValue().udiv(divisor.getValue());
  return IntegerAttr::get(dividend.getType(), quotient);
}
|
|
|
|
|
2018-10-09 15:04:27 -07:00
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// DmaStartOp
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
2018-11-08 17:31:01 -08:00
|
|
|
void DmaStartOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *srcMemRef, ArrayRef<Value *> srcIndices,
|
|
|
|
Value *destMemRef, ArrayRef<Value *> destIndices,
|
|
|
|
Value *numElements, Value *tagMemRef,
|
|
|
|
ArrayRef<Value *> tagIndices, Value *stride,
|
|
|
|
Value *elementsPerStride) {
|
2018-11-08 17:31:01 -08:00
|
|
|
result->addOperands(srcMemRef);
|
|
|
|
result->addOperands(srcIndices);
|
|
|
|
result->addOperands(destMemRef);
|
|
|
|
result->addOperands(destIndices);
|
|
|
|
result->addOperands(numElements);
|
|
|
|
result->addOperands(tagMemRef);
|
|
|
|
result->addOperands(tagIndices);
|
2018-12-05 15:30:25 -08:00
|
|
|
if (stride) {
|
|
|
|
result->addOperands(stride);
|
|
|
|
result->addOperands(elementsPerStride);
|
|
|
|
}
|
2018-11-08 17:31:01 -08:00
|
|
|
}
|
|
|
|
|
2018-10-09 15:04:27 -07:00
|
|
|
/// Prints the custom form:
///   dma_start %src[...], %dst[...], %num_elements, %tag[...]
///             [, %stride, %num_elements_per_stride] {attrs}
///             : src-type, dst-type, tag-type
void DmaStartOp::print(OpAsmPrinter *p) const {
  // Emits `memref[idx, ...]` for one memref operand and its index operands.
  auto printMemRefAccess = [&](auto *memref, auto indices) {
    *p << *memref << '[';
    p->printOperands(indices);
    *p << ']';
  };

  *p << getOperationName() << ' ';
  printMemRefAccess(getSrcMemRef(), getSrcIndices());
  *p << ", ";
  printMemRefAccess(getDstMemRef(), getDstIndices());
  *p << ", " << *getNumElements() << ", ";
  printMemRefAccess(getTagMemRef(), getTagIndices());

  // The stride pair is only present on strided DMAs.
  if (isStrided())
    *p << ", " << *getStride() << ", " << *getNumElementsPerStride();

  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getSrcMemRef()->getType() << ", "
     << getDstMemRef()->getType() << ", " << getTagMemRef()->getType();
}
|
|
|
|
|
|
|
|
// Parse DmaStartOp.
|
2018-10-18 11:14:26 -07:00
|
|
|
// Ex:
|
2018-10-09 15:04:27 -07:00
|
|
|
// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
|
|
|
|
// %tag[%index] :
|
|
|
|
// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
|
|
|
|
// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2>,
|
|
|
|
// memref<1 x i32, (d0) -> (d0), 4>
|
|
|
|
//
|
|
|
|
bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
|
|
|
|
OpAsmParser::OperandType srcMemRefInfo;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
|
|
|
|
OpAsmParser::OperandType dstMemRefInfo;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
|
|
|
|
OpAsmParser::OperandType numElementsInfo;
|
|
|
|
OpAsmParser::OperandType tagMemrefInfo;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
|
2018-12-05 15:30:25 -08:00
|
|
|
SmallVector<OpAsmParser::OperandType, 2> strideInfo;
|
2018-10-09 15:04:27 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
SmallVector<Type, 3> types;
|
|
|
|
auto indexType = parser->getBuilder().getIndexType();
|
2018-10-09 15:04:27 -07:00
|
|
|
|
|
|
|
// Parse and resolve the following list of operands:
|
|
|
|
// *) source memref followed by its indices (in square brackets).
|
|
|
|
// *) destination memref followed by its indices (in square brackets).
|
|
|
|
// *) dma size in KiB.
|
|
|
|
if (parser->parseOperand(srcMemRefInfo) ||
|
|
|
|
parser->parseOperandList(srcIndexInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::Square) ||
|
|
|
|
parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
|
|
|
|
parser->parseOperandList(dstIndexInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::Square) ||
|
|
|
|
parser->parseComma() || parser->parseOperand(numElementsInfo) ||
|
|
|
|
parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
|
|
|
|
parser->parseOperandList(tagIndexInfos, -1,
|
2018-12-05 15:30:25 -08:00
|
|
|
OpAsmParser::Delimiter::Square))
|
|
|
|
return true;
|
|
|
|
|
|
|
|
// Parse optional stride and elements per stride.
|
|
|
|
if (parser->parseTrailingOperandList(strideInfo)) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
if (!strideInfo.empty() && strideInfo.size() != 2) {
|
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
"expected two stride related operands");
|
|
|
|
}
|
|
|
|
bool isStrided = strideInfo.size() == 2;
|
|
|
|
|
|
|
|
if (parser->parseColonTypeList(types))
|
2018-10-09 15:04:27 -07:00
|
|
|
return true;
|
|
|
|
|
|
|
|
if (types.size() != 3)
|
|
|
|
return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
|
|
|
|
|
|
|
|
if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
|
|
|
|
parser->resolveOperands(srcIndexInfos, indexType, result->operands) ||
|
|
|
|
parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
|
|
|
|
parser->resolveOperands(dstIndexInfos, indexType, result->operands) ||
|
|
|
|
// size should be an index.
|
|
|
|
parser->resolveOperand(numElementsInfo, indexType, result->operands) ||
|
|
|
|
parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
|
|
|
|
// tag indices should be index.
|
|
|
|
parser->resolveOperands(tagIndexInfos, indexType, result->operands))
|
|
|
|
return true;
|
|
|
|
|
2018-12-05 15:30:25 -08:00
|
|
|
if (isStrided) {
|
|
|
|
if (parser->resolveOperand(strideInfo[0], indexType, result->operands) ||
|
|
|
|
parser->resolveOperand(strideInfo[1], indexType, result->operands))
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2018-10-09 15:04:27 -07:00
|
|
|
// Check that source/destination index list size matches associated rank.
|
2018-10-30 14:59:22 -07:00
|
|
|
if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
|
|
|
|
dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
|
2018-10-09 15:04:27 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
"memref rank not equal to indices count");
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
|
2018-10-09 15:04:27 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
"tag memref rank not equal to indices count");
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-12-05 15:30:25 -08:00
|
|
|
bool DmaStartOp::verify() const {
|
|
|
|
// DMAs from different memory spaces supported.
|
|
|
|
if (getSrcMemorySpace() == getDstMemorySpace()) {
|
|
|
|
return emitOpError("DMA should be between different memory spaces");
|
|
|
|
}
|
|
|
|
|
|
|
|
if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
|
|
|
|
getDstMemRefRank() + 3 + 1 &&
|
|
|
|
getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
|
|
|
|
getDstMemRefRank() + 3 + 1 + 2) {
|
|
|
|
return emitOpError("incorrect number of operands");
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
/// dma_start(memrefcast) -> dma_start
|
|
|
|
results.push_back(
|
|
|
|
std::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
|
|
}
|
|
|
|
|
2018-10-09 15:04:27 -07:00
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// DmaWaitOp
|
|
|
|
// ---------------------------------------------------------------------------
|
2018-10-18 11:14:26 -07:00
|
|
|
|
2018-11-08 17:31:01 -08:00
|
|
|
void DmaWaitOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *tagMemRef, ArrayRef<Value *> tagIndices,
|
|
|
|
Value *numElements) {
|
2018-11-08 17:31:01 -08:00
|
|
|
result->addOperands(tagMemRef);
|
|
|
|
result->addOperands(tagIndices);
|
|
|
|
result->addOperands(numElements);
|
|
|
|
}
|
|
|
|
|
2018-10-09 15:04:27 -07:00
|
|
|
/// Prints `dma_wait %tag[%i, ...], %num_elements {attrs} : tag-memref-type`.
void DmaWaitOp::print(OpAsmPrinter *p) const {
  auto *tag = getTagMemRef();

  *p << getOperationName() << ' ';
  p->printOperand(tag);
  *p << '[';
  p->printOperands(getTagIndices());
  *p << "], ";
  p->printOperand(getNumElements());

  p->printOptionalAttrDict(getAttrs());
  *p << " : " << tag->getType();
}
|
|
|
|
|
2018-10-18 11:14:26 -07:00
|
|
|
// Parse DmaWaitOp.
|
|
|
|
// Eg:
|
|
|
|
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
|
|
|
|
//
|
2018-10-09 15:04:27 -07:00
|
|
|
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
|
|
|
OpAsmParser::OperandType tagMemrefInfo;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
|
2018-10-30 14:59:22 -07:00
|
|
|
Type type;
|
|
|
|
auto indexType = parser->getBuilder().getIndexType();
|
2018-10-18 11:14:26 -07:00
|
|
|
OpAsmParser::OperandType numElementsInfo;
|
2018-10-09 15:04:27 -07:00
|
|
|
|
2018-10-18 11:14:26 -07:00
|
|
|
// Parse tag memref, its indices, and dma size.
|
2018-10-09 15:04:27 -07:00
|
|
|
if (parser->parseOperand(tagMemrefInfo) ||
|
|
|
|
parser->parseOperandList(tagIndexInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::Square) ||
|
2018-10-18 11:14:26 -07:00
|
|
|
parser->parseComma() || parser->parseOperand(numElementsInfo) ||
|
2018-10-09 15:04:27 -07:00
|
|
|
parser->parseColonType(type) ||
|
|
|
|
parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
|
2018-10-18 11:14:26 -07:00
|
|
|
parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
|
|
|
|
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
2018-10-09 15:04:27 -07:00
|
|
|
return true;
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
|
2018-10-09 15:04:27 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
"tag memref rank not equal to indices count");
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
/// dma_wait(memrefcast) -> dma_wait
|
|
|
|
results.push_back(
|
|
|
|
std::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
|
|
}
|
|
|
|
|
2018-08-23 09:58:23 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ExtractElementOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void ExtractElementOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *aggregate, ArrayRef<Value *> indices) {
|
2018-10-30 14:59:22 -07:00
|
|
|
auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
|
2018-08-23 09:58:23 -07:00
|
|
|
result->addOperands(aggregate);
|
|
|
|
result->addOperands(indices);
|
2018-10-30 14:59:22 -07:00
|
|
|
result->types.push_back(aggregateType.getElementType());
|
2018-08-23 09:58:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Prints `extract_element %aggregate[%i, ...] {attrs} : aggregate-type`.
void ExtractElementOp::print(OpAsmPrinter *p) const {
  auto *aggregate = getAggregate();

  *p << "extract_element " << *aggregate << '[';
  p->printOperands(getIndices());
  *p << ']';
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << aggregate->getType();
}
|
|
|
|
|
|
|
|
/// Parses `extract_element %aggregate[%i, ...] {attrs} : vector-or-tensor-type`.
/// Indices resolve as `index`, and the result type is the aggregate's element
/// type. Returns true on failure.
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType aggregate;
  SmallVector<OpAsmParser::OperandType, 4> indices;
  VectorOrTensorType aggregateType;
  auto indexTy = parser->getBuilder().getIndexType();

  // Syntax first.
  if (parser->parseOperand(aggregate) ||
      parser->parseOperandList(indices, -1, OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(aggregateType))
    return true;

  // Then resolve operands and record the result type.
  return parser->resolveOperand(aggregate, aggregateType, result->operands) ||
         parser->resolveOperands(indices, indexTy, result->operands) ||
         parser->addTypeToList(aggregateType.getElementType(), result->types);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool ExtractElementOp::verify() const {
|
2018-08-23 09:58:23 -07:00
|
|
|
if (getNumOperands() == 0)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("expected an aggregate to index into");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
|
2018-08-23 09:58:23 -07:00
|
|
|
if (!aggregateType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("first operand must be a vector or tensor");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (getType() != aggregateType.getElementType())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type must match element type of aggregate");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
|
|
|
for (auto *idx : getIndices())
|
2018-10-30 14:59:22 -07:00
|
|
|
if (!idx->getType().isIndex())
|
2018-10-06 17:21:53 -07:00
|
|
|
return emitOpError("index to extract_element must have 'index' type");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
|
|
|
// Verify the # indices match if we have a ranked type.
|
2018-10-30 14:59:22 -07:00
|
|
|
auto aggregateRank = aggregateType.getRank();
|
2018-08-23 09:58:23 -07:00
|
|
|
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of indices for extract_element");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-08-23 09:58:23 -07:00
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// LoadOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-12-27 14:35:10 -08:00
|
|
|
/// Builds `load memref[indices]`. The result has the memref's element type.
void LoadOp::build(Builder *builder, OperationState *result, Value *memref,
                   ArrayRef<Value *> indices) {
  Type elementType = memref->getType().cast<MemRefType>().getElementType();

  result->addOperands(memref);
  result->addOperands(indices);
  result->types.push_back(elementType);
}
|
|
|
|
|
2018-07-25 11:15:20 -07:00
|
|
|
/// Prints `load %memref[%i, ...] {attrs} : memref-type`.
void LoadOp::print(OpAsmPrinter *p) const {
  auto *memref = getMemRef();
  *p << "load " << *memref << '[';
  p->printOperands(getIndices());
  *p << ']';

  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getMemRefType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses `load %memref[%i, ...] {attrs} : memref-type`. Indices resolve as
/// `index`, and the result type is the memref's element type. Returns true on
/// failure.
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType memref;
  SmallVector<OpAsmParser::OperandType, 4> indices;
  MemRefType memrefType;
  auto indexTy = parser->getBuilder().getIndexType();

  // Syntax first.
  if (parser->parseOperand(memref) ||
      parser->parseOperandList(indices, -1, OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(memrefType))
    return true;

  // Then resolve operands and record the result type.
  return parser->resolveOperand(memref, memrefType, result->operands) ||
         parser->resolveOperands(indices, indexTy, result->operands) ||
         parser->addTypeToList(memrefType.getElementType(), result->types);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool LoadOp::verify() const {
|
2018-07-28 09:36:25 -07:00
|
|
|
if (getNumOperands() == 0)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("expected a memref to load from");
|
2018-07-24 10:13:31 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
|
2018-07-28 09:36:25 -07:00
|
|
|
if (!memRefType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("first operand must be a memref");
|
2018-07-24 10:13:31 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (getType() != memRefType.getElementType())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type must match element type of memref");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (memRefType.getRank() != getNumOperands() - 1)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of indices for load");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
for (auto *idx : getIndices())
|
2018-10-30 14:59:22 -07:00
|
|
|
if (!idx->getType().isIndex())
|
2018-10-06 17:21:53 -07:00
|
|
|
return emitOpError("index to load must have 'index' type");
|
2018-07-24 17:43:56 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
// TODO: Verify we have the right number of indices.
|
2018-07-24 17:43:56 -07:00
|
|
|
|
2018-12-28 08:48:09 -08:00
|
|
|
// TODO: in Function verify that the indices are parameters, IV's, or the
|
2018-07-28 09:36:25 -07:00
|
|
|
// result of an affine_apply.
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-24 10:13:31 -07:00
|
|
|
}
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
/// load(memrefcast) -> load
|
|
|
|
results.push_back(
|
|
|
|
std::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
|
|
}
|
|
|
|
|
2018-10-22 09:00:03 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MemRefCastOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
bool MemRefCastOp::verify() const {
|
2018-10-30 14:59:22 -07:00
|
|
|
auto opType = getOperand()->getType().dyn_cast<MemRefType>();
|
|
|
|
auto resType = getType().dyn_cast<MemRefType>();
|
2018-10-22 09:00:03 -07:00
|
|
|
if (!opType || !resType)
|
|
|
|
return emitOpError("requires input and result types to be memrefs");
|
|
|
|
|
|
|
|
if (opType == resType)
|
|
|
|
return emitOpError("requires the input and result type to be different");
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (opType.getElementType() != resType.getElementType())
|
2018-10-22 09:00:03 -07:00
|
|
|
return emitOpError(
|
|
|
|
"requires input and result element types to be the same");
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (opType.getAffineMaps() != resType.getAffineMaps())
|
2018-10-22 09:00:03 -07:00
|
|
|
return emitOpError("requires input and result mappings to be the same");
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (opType.getMemorySpace() != resType.getMemorySpace())
|
2018-10-22 09:00:03 -07:00
|
|
|
return emitOpError(
|
|
|
|
"requires input and result memory spaces to be the same");
|
|
|
|
|
|
|
|
// They must have the same rank, and any specified dimensions must match.
|
2018-10-30 14:59:22 -07:00
|
|
|
if (opType.getRank() != resType.getRank())
|
2018-10-22 09:00:03 -07:00
|
|
|
return emitOpError("requires input and result ranks to match");
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
|
|
|
|
int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
|
2018-10-22 09:00:03 -07:00
|
|
|
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
|
|
|
return emitOpError("requires static dimensions to match");
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-09-26 10:07:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MulFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Folds `mulf` when both operands are float constants of the same type.
/// Returns null otherwise.
Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
  assert(operands.size() == 2 && "mulf takes two operands");

  auto lhs = operands[0].dyn_cast_or_null<FloatAttr>();
  auto rhs = operands[1].dyn_cast_or_null<FloatAttr>();
  if (!lhs || !rhs || lhs.getType() != rhs.getType())
    return nullptr;

  return FloatAttr::get(lhs.getType(), lhs.getValue() * rhs.getValue());
}
|
|
|
|
|
2018-10-03 09:43:13 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MulIOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Folds `muli`:
///   0 * x == 0 and x * 0 == 0 (returns the zero constant operand),
///   c1 * c2 when both constants share a type.
/// Returns null when no fold applies.
Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
  assert(operands.size() == 2 && "muli takes two operands");

  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();

  // 0*x == 0
  if (lhs && lhs.getValue() == 0)
    return lhs;

  // Constant product.
  // TODO: Handle the overflow case.
  if (lhs && rhs && lhs.getType() == rhs.getType())
    return IntegerAttr::get(lhs.getType(), lhs.getValue() * rhs.getValue());

  // x*0 == 0
  if (rhs && rhs.getValue() == 0)
    return rhs;

  return nullptr;
}
|
|
|
|
|
2018-10-26 11:28:06 -07:00
|
|
|
namespace {
|
|
|
|
/// muli(x, 1) -> x
|
|
|
|
///
|
2018-11-28 15:09:39 -08:00
|
|
|
struct SimplifyMulX1 : public RewritePattern {
|
2018-10-26 11:28:06 -07:00
|
|
|
SimplifyMulX1(MLIRContext *context)
|
2018-11-28 15:09:39 -08:00
|
|
|
: RewritePattern(MulIOp::getOperationName(), 1, context) {}
|
2018-10-26 11:28:06 -07:00
|
|
|
|
2018-12-27 21:21:41 -08:00
|
|
|
PatternMatchResult match(OperationInst *op) const override {
|
2018-10-26 11:28:06 -07:00
|
|
|
auto muli = op->cast<MulIOp>();
|
2018-10-29 10:22:49 -07:00
|
|
|
|
2018-10-30 10:57:50 -07:00
|
|
|
if (matchPattern(muli->getOperand(1), m_One()))
|
2018-10-29 10:22:49 -07:00
|
|
|
return matchSuccess();
|
2018-10-26 11:28:06 -07:00
|
|
|
|
|
|
|
return matchFailure();
|
|
|
|
}
|
2018-12-27 21:21:41 -08:00
|
|
|
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
|
2018-11-24 07:40:55 -08:00
|
|
|
rewriter.replaceOp(op, op->getOperand(0));
|
2018-10-26 11:28:06 -07:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end anonymous namespace.
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-26 11:28:06 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
results.push_back(std::make_unique<SimplifyMulX1>(context));
|
|
|
|
}
|
|
|
|
|
2019-01-06 14:08:42 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// RemISOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Folds signed remainder:
///   x % 1 == 0 (even for non-constant x),
///   c1 % c2 via `srem` when c2 is a non-zero constant.
/// A zero divisor is never folded.
Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "remis takes two operands");

  // Every fold below needs a constant divisor.
  auto divisor = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!divisor)
    return {};
  APInt divisorValue = divisor.getValue();

  if (divisorValue.isOneValue())
    return IntegerAttr::get(divisor.getType(),
                            APInt(divisorValue.getBitWidth(), 0));

  if (divisorValue.isNullValue())
    return {};

  auto dividend = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!dividend)
    return {};
  return IntegerAttr::get(dividend.getType(),
                          dividend.getValue().srem(divisorValue));
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// RemIUOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Constant-folds an unsigned integer remainder (`remiu`).
///
/// `x % 1` folds to zero even when `x` is not a constant. When both operands
/// are IntegerAttr constants, the result is computed with APInt::urem.
/// Division by zero is never folded. Returns a null attribute when no fold
/// applies.
Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "remiu takes two operands");

  auto divisorAttr = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!divisorAttr)
    return {};
  APInt divisor = divisorAttr.getValue();

  // x % 1 = 0
  if (divisor.isOneValue()) {
    APInt zero(divisor.getBitWidth(), 0);
    return IntegerAttr::get(divisorAttr.getType(), zero);
  }

  // Don't fold if it requires division by zero.
  if (divisor.isNullValue())
    return {};

  auto dividendAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!dividendAttr)
    return {};
  return IntegerAttr::get(dividendAttr.getType(),
                          dividendAttr.getValue().urem(divisor));
}
|
|
|
|
|
2018-11-28 07:08:55 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SelectOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-11-28 15:09:39 -08:00
|
|
|
|
2018-12-27 14:35:10 -08:00
|
|
|
/// Builds a select with operands [condition, trueValue, falseValue]; the
/// result takes the type of `trueValue`.
void SelectOp::build(Builder *builder, OperationState *result, Value *condition,
                     Value *trueValue, Value *falseValue) {
  Type resultType = trueValue->getType();
  result->addOperands({condition, trueValue, falseValue});
  result->addTypes(resultType);
}
|
|
|
|
|
|
|
|
/// Parses `select %cond, %true, %false {attrs} : type`.
///
/// The condition operand is resolved to the type returned by getI1SameShape
/// for `type`. The two value operands and the result are resolved to `type`.
/// Returns true on failure, following the parser convention used here.
bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfos;
  Type valueType;

  bool syntaxError = parser->parseOperandList(operandInfos, 3) ||
                     parser->parseOptionalAttributeDict(result->attributes) ||
                     parser->parseColonType(valueType);
  if (syntaxError)
    return true;

  auto conditionType = getI1SameShape(&parser->getBuilder(), valueType);
  SmallVector<Type, 3> operandTypes = {conditionType, valueType, valueType};
  if (parser->resolveOperands(operandInfos, operandTypes, parser->getNameLoc(),
                              result->operands))
    return true;
  return parser->addTypeToList(valueType, result->types);
}
|
|
|
|
|
|
|
|
/// Prints `select %cond, %true, %false {attrs} : type`.
///
/// The optional attribute dictionary is emitted before the colon type. This
/// matches the order SelectOp::parse reads: operand list, attribute dict,
/// then `: type`. Printing the attributes after the type would produce text
/// that the parser rejects.
void SelectOp::print(OpAsmPrinter *p) const {
  *p << getOperationName() << ' ';
  p->printOperands(getInstruction()->getOperands());
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getTrueValue()->getType();
}
|
|
|
|
|
|
|
|
bool SelectOp::verify() const {
|
|
|
|
auto conditionType = getCondition()->getType();
|
|
|
|
auto trueType = getTrueValue()->getType();
|
|
|
|
auto falseType = getFalseValue()->getType();
|
|
|
|
|
|
|
|
if (trueType != falseType)
|
|
|
|
return emitOpError(
|
|
|
|
"requires 'true' and 'false' arguments to be of the same type");
|
|
|
|
|
|
|
|
if (checkI1SameShape(trueType, conditionType))
|
|
|
|
return emitOpError("requires the condition to have the same shape as "
|
|
|
|
"arguments with elemental type i1");
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Folds a select whose condition is a constant IntegerAttr:
///   select true, %0, %1 => %0
///   select false, %0, %1 => %1
/// Returns a null attribute when the condition is not an IntegerAttr.
Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
                                 MLIRContext *context) const {
  assert(operands.size() == 3 && "select takes three operands");

  auto condAttr = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!condAttr)
    return {};

  APInt condValue = condAttr.getValue();
  if (condValue.isOneValue())
    return operands[1];
  if (condValue.isNullValue())
    return operands[2];

  llvm_unreachable("first argument of select must be i1");
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// StoreOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-31 14:49:38 -07:00
|
|
|
void StoreOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *valueToStore, Value *memref,
|
|
|
|
ArrayRef<Value *> indices) {
|
2018-08-31 14:49:38 -07:00
|
|
|
result->addOperands(valueToStore);
|
|
|
|
result->addOperands(memref);
|
|
|
|
result->addOperands(indices);
|
|
|
|
}
|
|
|
|
|
2018-07-31 14:11:38 -07:00
|
|
|
/// Prints `store %value, %memref[%indices] {attrs} : memref-type`.
void StoreOp::print(OpAsmPrinter *p) const {
  OpAsmPrinter &printer = *p;
  printer << "store " << *getValueToStore() << ", " << *getMemRef() << '[';
  printer.printOperands(getIndices());
  printer << ']';
  printer.printOptionalAttrDict(getAttrs());
  printer << " : " << getMemRefType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses `store %value, %memref[%indices] {attrs} : memref-type`.
///
/// The stored value is resolved to the memref's element type, the memref to
/// the parsed memref type, and every index to `index`. Returns true on
/// failure.
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType valueOperand;
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  MemRefType memrefType;
  auto indexType = parser->getBuilder().getIndexType();

  // Read the surface syntax first, stopping at the first error.
  if (parser->parseOperand(valueOperand) || parser->parseComma() ||
      parser->parseOperand(memrefOperand) ||
      parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(memrefType))
    return true;

  // Resolve operands in [value, memref, indices...] order.
  return parser->resolveOperand(valueOperand, memrefType.getElementType(),
                                result->operands) ||
         parser->resolveOperand(memrefOperand, memrefType, result->operands) ||
         parser->resolveOperands(indexOperands, indexType, result->operands);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool StoreOp::verify() const {
|
2018-07-31 14:11:38 -07:00
|
|
|
if (getNumOperands() < 2)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("expected a value to store and a memref");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
// Second operand is a memref type.
|
2018-10-30 14:59:22 -07:00
|
|
|
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
|
2018-07-31 14:11:38 -07:00
|
|
|
if (!memRefType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("second operand must be a memref");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
// First operand must have same type as memref element type.
|
2018-10-30 14:59:22 -07:00
|
|
|
if (getValueToStore()->getType() != memRefType.getElementType())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("first operand must have same type memref element type");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (getNumOperands() != 2 + memRefType.getRank())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("store index operand count not equal to memref rank");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
for (auto *idx : getIndices())
|
2018-10-30 14:59:22 -07:00
|
|
|
if (!idx->getType().isIndex())
|
2018-10-06 17:21:53 -07:00
|
|
|
return emitOpError("index to load must have 'index' type");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
// TODO: Verify we have the right number of indices.
|
|
|
|
|
2018-12-28 08:48:09 -08:00
|
|
|
// TODO: in Function verify that the indices are parameters, IV's, or the
|
2018-07-31 14:11:38 -07:00
|
|
|
// result of an affine_apply.
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-31 14:11:38 -07:00
|
|
|
}
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
/// store(memrefcast) -> store
|
|
|
|
results.push_back(
|
|
|
|
std::make_unique<MemRefCastFolder>(getOperationName(), context));
|
|
|
|
}
|
|
|
|
|
2018-10-03 09:43:13 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SubFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Folds `subf` when both operands are FloatAttr constants of the same type.
/// Returns a null attribute otherwise.
Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
  assert(operands.size() == 2 && "subf takes two operands");

  auto lhs = operands[0].dyn_cast_or_null<FloatAttr>();
  auto rhs = operands[1].dyn_cast_or_null<FloatAttr>();
  if (!lhs || !rhs || lhs.getType() != rhs.getType())
    return nullptr;

  return FloatAttr::get(lhs.getType(), lhs.getValue() - rhs.getValue());
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SubIOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Folds `subi` when both operands are IntegerAttr constants of the same
/// type. Returns a null attribute otherwise.
Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
  assert(operands.size() == 2 && "subi takes two operands");

  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs || lhs.getType() != rhs.getType())
    return nullptr;

  return IntegerAttr::get(lhs.getType(), lhs.getValue() - rhs.getValue());
}
|
2018-10-25 16:44:04 -07:00
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// subi(x,x) -> 0
|
|
|
|
///
|
2018-11-28 15:09:39 -08:00
|
|
|
struct SimplifyXMinusX : public RewritePattern {
|
2018-10-25 16:44:04 -07:00
|
|
|
SimplifyXMinusX(MLIRContext *context)
|
2018-11-28 15:09:39 -08:00
|
|
|
: RewritePattern(SubIOp::getOperationName(), 1, context) {}
|
2018-10-25 16:44:04 -07:00
|
|
|
|
2018-12-27 21:21:41 -08:00
|
|
|
PatternMatchResult match(OperationInst *op) const override {
|
2018-10-25 16:44:04 -07:00
|
|
|
auto subi = op->cast<SubIOp>();
|
|
|
|
if (subi->getOperand(0) == subi->getOperand(1))
|
|
|
|
return matchSuccess();
|
|
|
|
|
|
|
|
return matchFailure();
|
|
|
|
}
|
2018-12-27 21:21:41 -08:00
|
|
|
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
|
2018-10-25 16:44:04 -07:00
|
|
|
auto subi = op->cast<SubIOp>();
|
|
|
|
auto result =
|
|
|
|
rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
|
|
|
|
|
2018-11-24 07:40:55 -08:00
|
|
|
rewriter.replaceOp(op, {result});
|
2018-10-25 16:44:04 -07:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end anonymous namespace.
|
|
|
|
|
2018-11-28 15:09:39 -08:00
|
|
|
void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
2018-10-25 16:44:04 -07:00
|
|
|
MLIRContext *context) {
|
|
|
|
results.push_back(std::make_unique<SimplifyXMinusX>(context));
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TensorCastOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
bool TensorCastOp::verify() const {
|
2018-10-30 14:59:22 -07:00
|
|
|
auto opType = getOperand()->getType().dyn_cast<TensorType>();
|
|
|
|
auto resType = getType().dyn_cast<TensorType>();
|
2018-10-25 16:44:04 -07:00
|
|
|
if (!opType || !resType)
|
|
|
|
return emitOpError("requires input and result types to be tensors");
|
|
|
|
|
|
|
|
if (opType == resType)
|
|
|
|
return emitOpError("requires the input and result type to be different");
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
if (opType.getElementType() != resType.getElementType())
|
2018-10-25 16:44:04 -07:00
|
|
|
return emitOpError(
|
|
|
|
"requires input and result element types to be the same");
|
|
|
|
|
|
|
|
// If the source or destination are unranked, then the cast is valid.
|
2018-10-30 14:59:22 -07:00
|
|
|
auto opRType = opType.dyn_cast<RankedTensorType>();
|
|
|
|
auto resRType = resType.dyn_cast<RankedTensorType>();
|
2018-10-25 16:44:04 -07:00
|
|
|
if (!opRType || !resRType)
|
|
|
|
return false;
|
|
|
|
|
|
|
|
// If they are both ranked, they have to have the same rank, and any specified
|
|
|
|
// dimensions must match.
|
2018-10-30 14:59:22 -07:00
|
|
|
if (opRType.getRank() != resRType.getRank())
|
2018-10-25 16:44:04 -07:00
|
|
|
return emitOpError("requires input and result ranks to match");
|
|
|
|
|
2018-10-30 14:59:22 -07:00
|
|
|
for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
|
|
|
|
int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
|
2018-10-25 16:44:04 -07:00
|
|
|
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
|
|
|
return emitOpError("requires static dimensions to match");
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
[MLIR] Add VectorTransferOps
This CL implements and uses VectorTransferOps in lieu of the former custom
call op. Tests are updated accordingly.
VectorTransferOps come in 2 flavors: VectorTransferReadOp and
VectorTransferWriteOp.
VectorTransferOps can be thought of as a backend-independent
pseudo op/library call that needs to be legalized to MLIR (whiteboxed) before
it can be lowered to backend-dependent IR.
Note that the current implementation does not yet support a real permutation
map. Proper support will come in a followup CL.
VectorTransferReadOp
====================
VectorTransferReadOp performs a blocking read from a scalar memref
location into a super-vector of the same elemental type. This operation is
called 'read' as opposed to 'load' because the super-vector granularity
is generally not representable with a single hardware register. As a
consequence, memory transfers will generally be required when lowering
VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
that supports super-vectorization with non-effecting padding for full-tile
only code.
A vector transfer read has semantics similar to a vector load, with additional
support for:
1. an optional value of the elemental type of the MemRef. This value
supports non-effecting padding and is inserted in places where the
vector read exceeds the MemRef bounds. If the value is not specified,
the access is statically guaranteed to be within bounds;
2. an attribute of type AffineMap to specify a slice of the original
MemRef access and its transposition into the super-vector shape. The
permutation_map is an unbounded AffineMap that must represent a
permutation from the MemRef dim space projected onto the vector dim
space.
Example:
```mlir
%src = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
...
%val = `ssa-value` : f32
// let %i, %j, %k, %l be ssa-values of type index
%v0 = vector_transfer_read %src, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index) ->
vector<16x32x64xf32>
%v1 = vector_transfer_read %src, %i, %j, %k, %l, %val
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index, f32) ->
vector<16x32x64xf32>
```
VectorTransferWriteOp
=====================
VectorTransferWriteOp performs a blocking write from a super-vector to
a scalar memref of the same elemental type. This operation is
called 'write' as opposed to 'store' because the super-vector
granularity is generally not representable with a single hardware register. As
a consequence, memory transfers will generally be required when lowering
VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
abstraction that supports super-vectorization with non-effecting padding
for full-tile only code.
A vector transfer write has semantics similar to a vector store, with
additional support for handling out-of-bounds situations.
Example:
```mlir
%src = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
%val = `ssa-value` : vector<16x32x64xf32>
// let %i, %j, %k, %l be ssa-values of type index
vector_transfer_write %val, %src, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index)
```
PiperOrigin-RevId: 223873234
2018-12-03 15:21:27 -08:00
|
|
|
|