2018-07-05 09:12:11 -07:00
|
|
|
//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
2018-10-10 14:23:30 -07:00
|
|
|
#include "mlir/StandardOps/StandardOps.h"
|
2018-09-24 10:23:02 -07:00
|
|
|
#include "mlir/IR/AffineExpr.h"
|
2018-07-24 10:13:31 -07:00
|
|
|
#include "mlir/IR/AffineMap.h"
|
2018-07-25 11:15:20 -07:00
|
|
|
#include "mlir/IR/Builders.h"
|
2018-10-10 14:23:30 -07:00
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
2018-07-24 16:07:22 -07:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2018-07-24 08:34:58 -07:00
|
|
|
#include "mlir/IR/SSAValue.h"
|
|
|
|
#include "mlir/IR/Types.h"
|
2018-10-03 10:07:54 -07:00
|
|
|
#include "mlir/Support/MathExtras.h"
|
2018-08-09 12:28:58 -07:00
|
|
|
#include "mlir/Support/STLExtras.h"
|
2018-07-05 09:12:11 -07:00
|
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
|
2018-10-21 19:49:31 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// StandardOpsDialect
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Constructs the standard ops dialect with an empty operation-name prefix and
// registers the operations listed below with it.
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
    : Dialect(/*opPrefix=*/"", context) {
  addOperations<AddFOp, AddIOp, AllocOp,          //
                CallOp, CallIndirectOp,           //
                DeallocOp, DimOp,                 //
                DmaStartOp, DmaWaitOp,            //
                ExtractElementOp, LoadOp,         //
                MulFOp, MulIOp,                   //
                ShapeCastOp, StoreOp,             //
                SubFOp, SubIOp>();
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AddFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-09-19 21:35:11 -07:00
|
|
|
// Folds `addf` when both operands are known float constants.
//
// `operands` holds one entry per operand; an entry that is null or not a
// FloatAttr means that operand is not a foldable float constant. Returns a
// FloatAttr holding the sum, or nullptr when folding is not possible.
Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "addf takes two operands");

  auto *lhsConst = dyn_cast_or_null<FloatAttr>(operands[0]);
  if (!lhsConst)
    return nullptr;

  auto *rhsConst = dyn_cast_or_null<FloatAttr>(operands[1]);
  if (!rhsConst)
    return nullptr;

  return FloatAttr::get(lhsConst->getValue() + rhsConst->getValue(), context);
}
|
|
|
|
|
2018-10-03 09:43:13 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AddIOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Folds `addi` when both operands are known integer constants.
//
// `operands` holds one entry per operand; an entry that is null or not an
// IntegerAttr means that operand is not a foldable integer constant. Returns
// an IntegerAttr holding the sum, or nullptr when folding is not possible.
Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "addi takes two operands");

  auto *lhsConst = dyn_cast_or_null<IntegerAttr>(operands[0]);
  if (!lhsConst)
    return nullptr;

  auto *rhsConst = dyn_cast_or_null<IntegerAttr>(operands[1]);
  if (!rhsConst)
    return nullptr;

  return IntegerAttr::get(lhsConst->getValue() + rhsConst->getValue(),
                          context);
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AllocOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-31 14:49:38 -07:00
|
|
|
// Populates `result` for an `alloc` operation: `operands` become the
// operation's operands and `memrefType` becomes its single result type.
// `builder` is not used here.
void AllocOp::build(Builder *builder, OperationState *result,
                    MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
  result->addOperands(operands);
  result->types.push_back(memrefType);
}
|
|
|
|
|
2018-07-30 13:08:05 -07:00
|
|
|
// Prints an `alloc` operation: the keyword, the operand list (split by
// printDimAndSymbolList at the memref's dynamic dimension count), the
// attribute dictionary with "map" elided, and finally the memref type.
void AllocOp::print(OpAsmPrinter *p) const {
  auto *memrefTy = cast<MemRefType>(getMemRef()->getType());

  *p << "alloc";

  // The first getNumDynamicDims() operands are printed as dimension operands.
  const auto dynamicDimCount = memrefTy->getNumDynamicDims();
  printDimAndSymbolList(operand_begin(), operand_end(), dynamicDimCount, p);

  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
  *p << " : " << *memrefTy;
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
// Parses an `alloc` operation: dimension operands and optional symbol
// operands, an optional attribute dictionary, then `: <memref type>`.
// Returns true on failure and false on success.
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
  MemRefType *memrefTy = nullptr;
  unsigned dimOperandCount;

  if (parseDimAndSymbolList(parser, result->operands, dimOperandCount))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(memrefTy))
    return true;

  // The number of dimension operands must equal the number of dynamic
  // dimensions in the memref type. This has to be checked here rather than in
  // verify(): once parsing is done, the split between dimension operands and
  // symbol operands is no longer available. verify() still checks the total
  // operand count against dynamic dimensions plus affine map symbols.
  if (dimOperandCount == memrefTy->getNumDynamicDims()) {
    result->types.push_back(memrefTy);
    return false;
  }

  return parser->emitError(parser->getNameLoc(),
                           "dimension operand count does not equal memref "
                           "dynamic dimension count");
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool AllocOp::verify() const {
|
2018-09-20 09:39:55 -07:00
|
|
|
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
|
|
|
if (!memRefType)
|
|
|
|
return emitOpError("result must be a memref");
|
|
|
|
|
|
|
|
unsigned numSymbols = 0;
|
|
|
|
if (!memRefType->getAffineMaps().empty()) {
|
2018-10-09 16:39:24 -07:00
|
|
|
AffineMap affineMap = memRefType->getAffineMaps()[0];
|
2018-09-20 09:39:55 -07:00
|
|
|
// Store number of symbols used in affine map (used in subsequent check).
|
2018-10-09 16:39:24 -07:00
|
|
|
numSymbols = affineMap.getNumSymbols();
|
2018-09-20 09:39:55 -07:00
|
|
|
// Verify that the layout affine map matches the rank of the memref.
|
2018-10-09 16:39:24 -07:00
|
|
|
if (affineMap.getNumDims() != memRefType->getRank())
|
2018-09-20 09:39:55 -07:00
|
|
|
return emitOpError("affine map dimension count must equal memref rank");
|
|
|
|
}
|
|
|
|
unsigned numDynamicDims = memRefType->getNumDynamicDims();
|
|
|
|
// Check that the total number of operands matches the number of symbols in
|
|
|
|
// the affine map, plus the number of dynamic dimensions specified in the
|
|
|
|
// memref type.
|
|
|
|
if (getOperation()->getNumOperands() != numDynamicDims + numSymbols) {
|
|
|
|
return emitOpError(
|
|
|
|
"operand count does not equal dimension plus symbol operand count");
|
|
|
|
}
|
2018-10-06 17:21:53 -07:00
|
|
|
// Verify that all operands are of type Index.
|
2018-09-20 09:39:55 -07:00
|
|
|
for (auto *operand : getOperands()) {
|
2018-10-06 17:21:53 -07:00
|
|
|
if (!operand->getType()->isIndex())
|
|
|
|
return emitOpError("requires operands to be of type Index");
|
2018-09-20 09:39:55 -07:00
|
|
|
}
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-30 13:08:05 -07:00
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-08-21 17:55:22 -07:00
|
|
|
// CallOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-22 19:25:49 -07:00
|
|
|
// Populates `result` for a direct `call`: `operands` become the call's
// operands, a function attribute for `callee` is stored under "callee", and
// the callee's result types become the call's result types.
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
                   ArrayRef<SSAValue *> operands) {
  result->addOperands(operands);

  auto *calleeAttr = builder->getFunctionAttr(callee);
  result->addAttribute("callee", calleeAttr);

  result->addTypes(callee->getType()->getResults());
}
|
|
|
|
|
|
|
|
// Parses a direct `call`: a function name, a parenthesized operand list, an
// optional attribute dictionary, then `: <function type>`. The name is
// resolved against the parsed function type, whose results and inputs give the
// call's result types and operand types. Returns true on failure and false on
// success.
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
  StringRef calleeName;
  llvm::SMLoc calleeLoc;
  FunctionType *calleeType = nullptr;
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Function *callee = nullptr;

  // Syntactic part.
  if (parser->parseFunctionName(calleeName, calleeLoc) ||
      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
                               OpAsmParser::Delimiter::Paren) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(calleeType))
    return true;

  // Resolution part: bind the name to a function, then type results/operands.
  if (parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee))
    return true;
  if (parser->addTypesToList(calleeType->getResults(), result->types))
    return true;
  if (parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
                              result->operands))
    return true;

  result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
  return false;
}
|
|
|
|
|
|
|
|
// Prints a direct `call`: `call <callee>(<operands>)`, the attribute
// dictionary with "callee" elided, then ` : <callee function type>`.
void CallOp::print(OpAsmPrinter *p) const {
  *p << "call ";
  p->printFunctionReference(getCallee());

  *p << '(';
  p->printOperands(getOperands());
  *p << ')';

  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");

  *p << " : " << *getCallee()->getType();
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool CallOp::verify() const {
|
2018-08-21 17:55:22 -07:00
|
|
|
// Check that the callee attribute was specified.
|
|
|
|
auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
|
|
|
|
if (!fnAttr)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires a 'callee' function attribute");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
// Verify that the operand and result types match the callee.
|
|
|
|
auto *fnType = fnAttr->getValue()->getType();
|
|
|
|
if (fnType->getNumInputs() != getNumOperands())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of operands for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
|
|
|
if (getOperand(i)->getType() != fnType->getInput(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("operand type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
if (fnType->getNumResults() != getNumResults())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of results for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
|
|
|
if (getResult(i)->getType() != fnType->getResult(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// CallIndirectOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-22 19:25:49 -07:00
|
|
|
// Populates `result` for a `call_indirect`: `callee` is added as the first
// operand, followed by `operands`; the result types are taken from the
// callee's function type. `callee` must have FunctionType (enforced by cast<>).
void CallIndirectOp::build(Builder *builder, OperationState *result,
                           SSAValue *callee, ArrayRef<SSAValue *> operands) {
  auto *calleeType = cast<FunctionType>(callee->getType());

  // Operand 0 is the callee; the call arguments follow it.
  result->operands.push_back(callee);
  result->addOperands(operands);

  result->addTypes(calleeType->getResults());
}
|
|
|
|
|
|
|
|
// Parses a `call_indirect`: the callee operand, a parenthesized argument
// list, an optional attribute dictionary, then `: <function type>`. The callee
// is resolved to the parsed function type, the arguments to its inputs, and
// its results become the operation's result types. Returns true on failure
// and false on success.
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
  FunctionType *calleeType = nullptr;
  OpAsmParser::OperandType callee;
  llvm::SMLoc operandsLoc;
  SmallVector<OpAsmParser::OperandType, 4> operands;

  // Syntactic part. The location of the argument list is recorded so that
  // operand resolution errors can point at it.
  if (parser->parseOperand(callee) ||
      parser->getCurrentLocation(&operandsLoc) ||
      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
                               OpAsmParser::Delimiter::Paren) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(calleeType))
    return true;

  // Resolution part: callee first so it ends up as operand 0.
  if (parser->resolveOperand(callee, calleeType, result->operands))
    return true;
  if (parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
                              result->operands))
    return true;

  return parser->addTypesToList(calleeType->getResults(), result->types);
}
|
|
|
|
|
|
|
|
// Prints a `call_indirect`: `call_indirect <callee>(<arguments>)`, the
// attribute dictionary with "callee" elided, then ` : <callee type>`.
void CallIndirectOp::print(OpAsmPrinter *p) const {
  *p << "call_indirect ";
  p->printOperand(getCallee());

  // The first operand is the callee itself, so the argument list starts one
  // past the beginning of the operand range.
  auto allOperands = getOperands();
  auto firstArgument = allOperands.begin();
  ++firstArgument;

  *p << '(';
  p->printOperands(firstArgument, allOperands.end());
  *p << ')';

  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
  *p << " : " << *getCallee()->getType();
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool CallIndirectOp::verify() const {
|
2018-08-21 17:55:22 -07:00
|
|
|
// The callee must be a function.
|
|
|
|
auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
|
|
|
|
if (!fnType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("callee must have function type");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
// Verify that the operand and result types match the callee.
|
|
|
|
if (fnType->getNumInputs() != getNumOperands() - 1)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of operands for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
|
|
|
if (getOperand(i + 1)->getType() != fnType->getInput(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("operand type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
if (fnType->getNumResults() != getNumResults())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of results for callee");
|
2018-08-21 17:55:22 -07:00
|
|
|
|
|
|
|
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
|
|
|
if (getResult(i)->getType() != fnType->getResult(i))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type mismatch");
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-08-21 17:55:22 -07:00
|
|
|
}
|
|
|
|
|
2018-08-15 15:39:26 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DeallocOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-31 14:49:38 -07:00
|
|
|
// Populates `result` for a `dealloc`: `memref` is its only operand and the
// operation is given no result types. `builder` is not used here.
void DeallocOp::build(Builder *builder, OperationState *result,
                      SSAValue *memref) {
  result->addOperands(memref);
}
|
|
|
|
|
2018-08-15 15:39:26 -07:00
|
|
|
// Prints a `dealloc` as `dealloc <memref> : <memref type>`.
void DeallocOp::print(OpAsmPrinter *p) const {
  *p << "dealloc " << *getMemRef();
  *p << " : " << *getMemRef()->getType();
}
|
|
|
|
|
|
|
|
// Parses a `dealloc`: one operand followed by `: <memref type>`, and resolves
// the operand to that type. Returns true on failure and false on success.
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType memrefInfo;
  MemRefType *memrefTy = nullptr;

  if (parser->parseOperand(memrefInfo))
    return true;
  if (parser->parseColonType(memrefTy))
    return true;
  return parser->resolveOperand(memrefInfo, memrefTy, result->operands);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool DeallocOp::verify() const {
|
2018-08-15 15:39:26 -07:00
|
|
|
if (!isa<MemRefType>(getMemRef()->getType()))
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("operand must be a memref");
|
|
|
|
return false;
|
2018-08-15 15:39:26 -07:00
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DimOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-09-26 16:21:49 -07:00
|
|
|
// Populates `result` for a `dim`: `memrefOrTensor` is the operand, `index` is
// stored as an integer attribute named "index", and the result has Index
// type.
void DimOp::build(Builder *builder, OperationState *result,
                  SSAValue *memrefOrTensor, unsigned index) {
  result->addOperands(memrefOrTensor);

  auto *indexAttr = builder->getIntegerAttr(index);
  result->addAttribute("index", indexAttr);

  result->types.push_back(builder->getIndexType());
}
|
|
|
|
|
2018-07-24 16:07:22 -07:00
|
|
|
// Prints a `dim` as `dim <operand>, <index>`, the attribute dictionary with
// "index" elided, then ` : <operand type>`.
void DimOp::print(OpAsmPrinter *p) const {
  *p << "dim " << *getOperand();
  *p << ", " << getIndex();

  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");

  *p << " : " << *getOperand()->getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
// Parses a `dim`: an operand, a comma, an integer attribute stored as
// "index", an optional attribute dictionary, then `: <operand type>`. The
// result type is always Index. Returns true on failure and false on success.
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType operandInfo;
  IntegerAttr *indexAttr;
  Type *operandTy;

  // Syntactic part.
  if (parser->parseOperand(operandInfo) || parser->parseComma() ||
      parser->parseAttribute(indexAttr, "index", result->attributes) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(operandTy))
    return true;

  // Resolution part.
  if (parser->resolveOperand(operandInfo, operandTy, result->operands))
    return true;

  return parser->addTypeToList(parser->getBuilder().getIndexType(),
                               result->types);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool DimOp::verify() const {
|
2018-07-06 10:46:19 -07:00
|
|
|
// Check that we have an integer index operand.
|
|
|
|
auto indexAttr = getAttrOfType<IntegerAttr>("index");
|
|
|
|
if (!indexAttr)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires an integer attribute named 'index'");
|
2018-07-24 08:34:58 -07:00
|
|
|
uint64_t index = (uint64_t)indexAttr->getValue();
|
|
|
|
|
|
|
|
auto *type = getOperand()->getType();
|
|
|
|
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
|
|
|
|
if (index >= tensorType->getRank())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("index is out of range");
|
2018-07-24 08:34:58 -07:00
|
|
|
} else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
|
|
|
|
if (index >= memrefType->getRank())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("index is out of range");
|
2018-07-06 10:46:19 -07:00
|
|
|
|
2018-07-24 08:34:58 -07:00
|
|
|
} else if (isa<UnrankedTensorType>(type)) {
|
|
|
|
// ok, assumed to be in-range.
|
|
|
|
} else {
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires an operand with tensor or memref type");
|
2018-07-24 08:34:58 -07:00
|
|
|
}
|
2018-07-06 10:46:19 -07:00
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-06 10:46:19 -07:00
|
|
|
}
|
|
|
|
|
2018-10-05 09:28:49 -07:00
|
|
|
// Folds `dim` when the operand's type gives a non-negative size for the
// queried dimension.
//
// Only the operand's type is inspected; `operands` is not read. Returns an
// IntegerAttr holding the size when the operand is a ranked tensor or memref
// whose shape entry at getIndex() is non-negative, and nullptr otherwise.
Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands,
                               MLIRContext *context) const {
  auto *operandTy = getOperand()->getType();

  // -1 means "no size found"; a negative shape entry is treated the same way.
  int extent = -1;
  if (auto *rankedTensorTy = dyn_cast<RankedTensorType>(operandTy))
    extent = rankedTensorTy->getShape()[getIndex()];
  else if (auto *memrefTy = dyn_cast<MemRefType>(operandTy))
    extent = memrefTy->getShape()[getIndex()];

  if (extent < 0)
    return nullptr;
  return IntegerAttr::get(extent, context);
}
|
|
|
|
|
2018-10-09 15:04:27 -07:00
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// DmaStartOp
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
// Prints a DMA start operation as
//   <name> <src>[<src indices>], <dst>[<dst indices>], <num elements>,
//          <tag>[<tag indices>] {attrs} : <src type>, <dst type>, <tag type>
void DmaStartOp::print(OpAsmPrinter *p) const {
  // Source memref and its indices.
  *p << getOperationName() << ' ' << *getSrcMemRef() << '[';
  p->printOperands(getSrcIndices());

  // Destination memref and its indices.
  *p << "], " << *getDstMemRef() << '[';
  p->printOperands(getDstIndices());

  // Element count, then the tag memref and its indices.
  *p << "], " << *getNumElements() << ", " << *getTagMemRef() << '[';
  p->printOperands(getTagIndices());
  *p << ']';

  p->printOptionalAttrDict(getAttrs());

  // Trailing type list, in the same order as the memref operands.
  *p << " : " << *getSrcMemRef()->getType() << ", "
     << *getDstMemRef()->getType() << ", " << *getTagMemRef()->getType();
}
|
|
|
|
|
|
|
|
// Parse DmaStartOp.
|
2018-10-18 11:14:26 -07:00
|
|
|
// Ex:
|
2018-10-09 15:04:27 -07:00
|
|
|
// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
|
|
|
|
// %tag[%index] :
|
|
|
|
// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
|
|
|
|
// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2>,
|
|
|
|
// memref<1 x i32, (d0) -> (d0), 4>
|
|
|
|
//
|
|
|
|
bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
|
|
|
|
OpAsmParser::OperandType srcMemRefInfo;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
|
|
|
|
OpAsmParser::OperandType dstMemRefInfo;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
|
|
|
|
OpAsmParser::OperandType numElementsInfo;
|
|
|
|
OpAsmParser::OperandType tagMemrefInfo;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
|
|
|
|
|
|
|
|
SmallVector<Type *, 3> types;
|
|
|
|
auto *indexType = parser->getBuilder().getIndexType();
|
|
|
|
|
|
|
|
// Parse and resolve the following list of operands:
|
|
|
|
// *) source memref followed by its indices (in square brackets).
|
|
|
|
// *) destination memref followed by its indices (in square brackets).
|
|
|
|
// *) dma size in KiB.
|
|
|
|
if (parser->parseOperand(srcMemRefInfo) ||
|
|
|
|
parser->parseOperandList(srcIndexInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::Square) ||
|
|
|
|
parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
|
|
|
|
parser->parseOperandList(dstIndexInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::Square) ||
|
|
|
|
parser->parseComma() || parser->parseOperand(numElementsInfo) ||
|
|
|
|
parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
|
|
|
|
parser->parseOperandList(tagIndexInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::Square) ||
|
|
|
|
parser->parseColonTypeList(types))
|
|
|
|
return true;
|
|
|
|
|
|
|
|
if (types.size() != 3)
|
|
|
|
return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
|
|
|
|
|
|
|
|
if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
|
|
|
|
parser->resolveOperands(srcIndexInfos, indexType, result->operands) ||
|
|
|
|
parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
|
|
|
|
parser->resolveOperands(dstIndexInfos, indexType, result->operands) ||
|
|
|
|
// size should be an index.
|
|
|
|
parser->resolveOperand(numElementsInfo, indexType, result->operands) ||
|
|
|
|
parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
|
|
|
|
// tag indices should be index.
|
|
|
|
parser->resolveOperands(tagIndexInfos, indexType, result->operands))
|
|
|
|
return true;
|
|
|
|
|
|
|
|
// Check that source/destination index list size matches associated rank.
|
|
|
|
if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
|
|
|
|
dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
|
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
"memref rank not equal to indices count");
|
|
|
|
|
|
|
|
if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
|
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
"tag memref rank not equal to indices count");
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// DmaWaitOp
|
|
|
|
// ---------------------------------------------------------------------------
|
2018-10-18 11:14:26 -07:00
|
|
|
|
2018-10-09 15:04:27 -07:00
|
|
|
// Prints a DMA wait operation as
//   <name> <tag>[<tag indices>], <num elements> : <tag type>
void DmaWaitOp::print(OpAsmPrinter *p) const {
  *p << getOperationName() << ' ';

  // Tag memref with its bracketed indices.
  p->printOperand(getTagMemRef());
  *p << '[';
  p->printOperands(getTagIndices());
  *p << "], ";

  // Element count, then the tag memref's type.
  p->printOperand(getNumElements());
  *p << " : " << *getTagMemRef()->getType();
}
|
|
|
|
|
2018-10-18 11:14:26 -07:00
|
|
|
// Parse DmaWaitOp.
|
|
|
|
// Eg:
|
|
|
|
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
|
|
|
|
//
|
2018-10-09 15:04:27 -07:00
|
|
|
// Parses a DMA wait operation: the tag memref with its bracketed indices, a
// comma, the element-count operand, then `: <tag memref type>`. The tag
// indices and the element count are resolved as Index values. Returns true on
// failure and false on success.
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type *type;
  auto *indexType = parser->getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser->parseOperand(tagMemrefInfo) ||
      parser->parseOperandList(tagIndexInfos, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseComma() || parser->parseOperand(numElementsInfo) ||
      parser->parseColonType(type) ||
      parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
      parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
      parser->resolveOperand(numElementsInfo, indexType, result->operands))
    return true;

  // The type comes straight from the input text, so it may be anything. Use
  // dyn_cast and report an error rather than cast<>, which asserts on a
  // non-memref type.
  auto *tagType = dyn_cast<MemRefType>(type);
  if (!tagType)
    return parser->emitError(parser->getNameLoc(),
                             "expected tag to be of memref type");

  if (tagIndexInfos.size() != tagType->getRank())
    return parser->emitError(parser->getNameLoc(),
                             "tag memref rank not equal to indices count");

  return false;
}
|
|
|
|
|
2018-08-23 09:58:23 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ExtractElementOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Populates `result` for an `extract_element`: `aggregate` is the first
// operand, followed by `indices`; the result type is the aggregate's element
// type. `aggregate` must have a vector or tensor type (enforced by cast<>).
void ExtractElementOp::build(Builder *builder, OperationState *result,
                             SSAValue *aggregate,
                             ArrayRef<SSAValue *> indices) {
  auto *containerTy = cast<VectorOrTensorType>(aggregate->getType());

  result->addOperands(aggregate);
  result->addOperands(indices);

  result->types.push_back(containerTy->getElementType());
}
|
|
|
|
|
|
|
|
// Prints an `extract_element` as
//   extract_element <aggregate>[<indices>] {attrs} : <aggregate type>
void ExtractElementOp::print(OpAsmPrinter *p) const {
  *p << "extract_element " << *getAggregate();

  *p << '[';
  p->printOperands(getIndices());
  *p << ']';

  p->printOptionalAttrDict(getAttrs());

  *p << " : " << *getAggregate()->getType();
}
|
|
|
|
|
|
|
|
// Parses an `extract_element`: the aggregate operand, a bracketed index list,
// an optional attribute dictionary, then `: <aggregate type>`. Indices are
// resolved as Index values and the result type is the aggregate's element
// type. Returns true on failure and false on success.
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType aggregateInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  VectorOrTensorType *aggregateTy;

  auto indexTy = parser->getBuilder().getIndexType();

  // Syntactic part.
  if (parser->parseOperand(aggregateInfo) ||
      parser->parseOperandList(indexInfo, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(aggregateTy))
    return true;

  // Resolution part: the aggregate first, then its indices.
  if (parser->resolveOperand(aggregateInfo, aggregateTy, result->operands) ||
      parser->resolveOperands(indexInfo, indexTy, result->operands))
    return true;

  return parser->addTypeToList(aggregateTy->getElementType(), result->types);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool ExtractElementOp::verify() const {
|
2018-08-23 09:58:23 -07:00
|
|
|
if (getNumOperands() == 0)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("expected an aggregate to index into");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
|
|
|
auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
|
|
|
|
if (!aggregateType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("first operand must be a vector or tensor");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-10-13 20:36:03 -07:00
|
|
|
if (getType() != aggregateType->getElementType())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type must match element type of aggregate");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
|
|
|
for (auto *idx : getIndices())
|
2018-10-06 17:21:53 -07:00
|
|
|
if (!idx->getType()->isIndex())
|
|
|
|
return emitOpError("index to extract_element must have 'index' type");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
|
|
|
// Verify the # indices match if we have a ranked type.
|
2018-10-09 16:49:39 -07:00
|
|
|
auto aggregateRank = aggregateType->getRank();
|
2018-08-23 09:58:23 -07:00
|
|
|
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of indices for extract_element");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-08-23 09:58:23 -07:00
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// LoadOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-23 09:58:23 -07:00
|
|
|
// Populates `result` for a `load`: `memref` is the first operand, followed by
// `indices`; the result type is the memref's element type. `memref` must have
// memref type (enforced by cast<>).
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
                   ArrayRef<SSAValue *> indices) {
  auto *sourceTy = cast<MemRefType>(memref->getType());

  result->addOperands(memref);
  result->addOperands(indices);

  result->types.push_back(sourceTy->getElementType());
}
|
|
|
|
|
2018-07-25 11:15:20 -07:00
|
|
|
// Prints a `load` as `load <memref>[<indices>] {attrs} : <memref type>`.
void LoadOp::print(OpAsmPrinter *p) const {
  *p << "load " << *getMemRef();

  *p << '[';
  p->printOperands(getIndices());
  *p << ']';

  p->printOptionalAttrDict(getAttrs());

  *p << " : " << *getMemRef()->getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
// Parses a `load`: the memref operand, a bracketed index list, an optional
// attribute dictionary, then `: <memref type>`. Indices are resolved as Index
// values and the result type is the memref's element type. Returns true on
// failure and false on success.
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  MemRefType *memrefTy;

  auto indexTy = parser->getBuilder().getIndexType();

  // Syntactic part.
  if (parser->parseOperand(memrefInfo) ||
      parser->parseOperandList(indexInfo, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(memrefTy))
    return true;

  // Resolution part: the memref first, then its indices.
  if (parser->resolveOperand(memrefInfo, memrefTy, result->operands) ||
      parser->resolveOperands(indexInfo, indexTy, result->operands))
    return true;

  return parser->addTypeToList(memrefTy->getElementType(), result->types);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool LoadOp::verify() const {
|
2018-07-28 09:36:25 -07:00
|
|
|
if (getNumOperands() == 0)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("expected a memref to load from");
|
2018-07-24 10:13:31 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
|
|
|
if (!memRefType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("first operand must be a memref");
|
2018-07-24 10:13:31 -07:00
|
|
|
|
2018-10-13 20:36:03 -07:00
|
|
|
if (getType() != memRefType->getElementType())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("result type must match element type of memref");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
|
|
|
if (memRefType->getRank() != getNumOperands() - 1)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("incorrect number of indices for load");
|
2018-08-23 09:58:23 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
for (auto *idx : getIndices())
|
2018-10-06 17:21:53 -07:00
|
|
|
if (!idx->getType()->isIndex())
|
|
|
|
return emitOpError("index to load must have 'index' type");
|
2018-07-24 17:43:56 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
// TODO: Verify we have the right number of indices.
|
2018-07-24 17:43:56 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
// TODO: in MLFunction verify that the indices are parameters, IV's, or the
|
|
|
|
// result of an affine_apply.
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-24 10:13:31 -07:00
|
|
|
}
|
|
|
|
|
2018-09-26 10:07:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MulFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Folds mulf when both operands are FloatAttr constants, producing a
// FloatAttr holding their product. Returns nullptr when either operand is
// null or not a FloatAttr.
Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "mulf takes two operands");

  auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0]);
  if (!lhs)
    return nullptr;

  auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]);
  if (!rhs)
    return nullptr;

  return FloatAttr::get(lhs->getValue() * rhs->getValue(), context);
}
|
|
|
|
|
2018-10-03 09:43:13 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MulIOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Folds muli in three situations, tried in this order:
//   1. the left operand is the constant 0  -> that zero attribute,
//   2. both operands are integer constants -> their product,
//   3. the right operand is the constant 0 -> that zero attribute.
// Returns nullptr when none applies.
Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "muli takes two operands");

  auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
  auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]);

  // 0*x == 0, regardless of whether x is a known constant.
  if (lhs && lhs->getValue() == 0)
    return lhs;

  // Both sides constant: multiply them.
  // TODO: Handle the overflow case.
  if (lhs && rhs)
    return IntegerAttr::get(lhs->getValue() * rhs->getValue(), context);

  // x*0 == 0; only reachable when the left operand is not a known integer
  // constant.
  if (rhs && rhs->getValue() == 0)
    return rhs;

  return nullptr;
}
|
|
|
|
|
2018-09-06 17:31:21 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ShapeCastOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Populates `result` for a shape_cast: `input` is the only operand and
// `resultType` is the only result type. No validation happens here; see
// ShapeCastOp::verify for the constraints on the two types.
void ShapeCastOp::build(Builder *builder, OperationState *result,
                        SSAValue *input, Type *resultType) {
  // The value being cast.
  result->addOperands(input);
  // The type it is cast to.
  result->addTypes(resultType);
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool ShapeCastOp::verify() const {
|
2018-09-06 17:31:21 -07:00
|
|
|
auto *opType = dyn_cast<TensorType>(getOperand()->getType());
|
2018-10-13 20:36:03 -07:00
|
|
|
auto *resType = dyn_cast<TensorType>(getType());
|
2018-09-06 17:31:21 -07:00
|
|
|
if (!opType || !resType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires input and result types to be tensors");
|
2018-09-06 17:31:21 -07:00
|
|
|
|
|
|
|
if (opType == resType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires the input and result type to be different");
|
2018-09-06 17:31:21 -07:00
|
|
|
|
|
|
|
if (opType->getElementType() != resType->getElementType())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError(
|
|
|
|
"requires input and result element types to be the same");
|
2018-09-06 17:31:21 -07:00
|
|
|
|
|
|
|
// If the source or destination are unranked, then the cast is valid.
|
|
|
|
auto *opRType = dyn_cast<RankedTensorType>(opType);
|
|
|
|
auto *resRType = dyn_cast<RankedTensorType>(resType);
|
|
|
|
if (!opRType || !resRType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-09-06 17:31:21 -07:00
|
|
|
|
|
|
|
// If they are both ranked, they have to have the same rank, and any specified
|
|
|
|
// dimensions must match.
|
|
|
|
if (opRType->getRank() != resRType->getRank())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires input and result ranks to match");
|
2018-09-06 17:31:21 -07:00
|
|
|
|
|
|
|
for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
|
|
|
|
int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
|
|
|
|
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("requires static dimensions to match");
|
2018-09-06 17:31:21 -07:00
|
|
|
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-09-06 17:31:21 -07:00
|
|
|
}
|
|
|
|
|
2018-09-13 09:16:32 -07:00
|
|
|
// Prints the op as:
//   shape_cast <operand> : <operand-type> to <result-type>
void ShapeCastOp::print(OpAsmPrinter *p) const {
  auto *source = getOperand();
  *p << "shape_cast " << *source;
  *p << " : " << *source->getType();
  *p << " to " << *getType();
}
|
|
|
|
|
|
|
|
// Parses the textual form emitted by ShapeCastOp::print:
//   <operand> : <operand-type> to <result-type>
// Each parser call is treated as returning true on failure; this function
// returns true at the first failing step and false once every step succeeded.
bool ShapeCastOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType sourceOperand;
  Type *sourceTy = nullptr;
  Type *resultTy = nullptr;

  // Operand and its type; the operand is resolved as soon as its type is
  // known.
  if (parser->parseOperand(sourceOperand) ||
      parser->parseColonType(sourceTy) ||
      parser->resolveOperand(sourceOperand, sourceTy, result->operands))
    return true;

  // `to <type>` names the single result type.
  if (parser->parseKeywordType("to", resultTy) ||
      parser->addTypeToList(resultTy, result->types))
    return true;

  return false;
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// StoreOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-31 14:49:38 -07:00
|
|
|
// Populates `result` for a store. The op has no results; its operands are,
// in order: the value to store, the destination memref, and the indices.
void StoreOp::build(Builder *builder, OperationState *result,
                    SSAValue *valueToStore, SSAValue *memref,
                    ArrayRef<SSAValue *> indices) {
  // Operand #0: the stored value.
  result->addOperands(valueToStore);
  // Operand #1: the memref written to.
  result->addOperands(memref);
  // Remaining operands: the indices.
  result->addOperands(indices);
}
|
|
|
|
|
2018-07-31 14:11:38 -07:00
|
|
|
// Prints the op as:
//   store <value>, <memref>[<indices>] <optional-attr-dict> : <memref-type>
void StoreOp::print(OpAsmPrinter *p) const {
  auto *target = getMemRef();

  // Stored value, then the memref with its bracketed index list.
  *p << "store " << *getValueToStore() << ", " << *target << '[';
  p->printOperands(getIndices());
  *p << ']';

  // Attributes (if any) come before the trailing memref type.
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << *target->getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
// Parses the textual form emitted by StoreOp::print:
//   <value>, <memref>[<indices>] <optional-attr-dict> : <memref-type>
// Each parser call is treated as returning true on failure; this function
// returns true at the first failing step and false once every step succeeded.
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType valueOperand;
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  MemRefType *memrefTy = nullptr;

  // All indices are resolved against the builder's index type.
  auto indexTy = parser->getBuilder().getIndexType();

  // Syntactic part. The -1 operand count presumably means "any number of
  // indices" -- TODO confirm against OpAsmParser::parseOperandList.
  if (parser->parseOperand(valueOperand) || parser->parseComma() ||
      parser->parseOperand(memrefOperand) ||
      parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(memrefTy))
    return true;

  // Operand resolution, in operand order: value, memref, indices. Only the
  // memref type is spelled in the syntax, so the stored value is resolved
  // against its element type.
  if (parser->resolveOperand(valueOperand, memrefTy->getElementType(),
                             result->operands) ||
      parser->resolveOperand(memrefOperand, memrefTy, result->operands) ||
      parser->resolveOperands(indexOperands, indexTy, result->operands))
    return true;

  return false;
}
|
|
|
|
|
2018-09-09 20:40:23 -07:00
|
|
|
bool StoreOp::verify() const {
|
2018-07-31 14:11:38 -07:00
|
|
|
if (getNumOperands() < 2)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("expected a value to store and a memref");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
// Second operand is a memref type.
|
|
|
|
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
|
|
|
if (!memRefType)
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("second operand must be a memref");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
// First operand must have same type as memref element type.
|
|
|
|
if (getValueToStore()->getType() != memRefType->getElementType())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("first operand must have same type memref element type");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
if (getNumOperands() != 2 + memRefType->getRank())
|
2018-09-09 20:40:23 -07:00
|
|
|
return emitOpError("store index operand count not equal to memref rank");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
for (auto *idx : getIndices())
|
2018-10-06 17:21:53 -07:00
|
|
|
if (!idx->getType()->isIndex())
|
|
|
|
return emitOpError("index to load must have 'index' type");
|
2018-07-31 14:11:38 -07:00
|
|
|
|
|
|
|
// TODO: Verify we have the right number of indices.
|
|
|
|
|
|
|
|
// TODO: in MLFunction verify that the indices are parameters, IV's, or the
|
|
|
|
// result of an affine_apply.
|
2018-09-09 20:40:23 -07:00
|
|
|
return false;
|
2018-07-31 14:11:38 -07:00
|
|
|
}
|
|
|
|
|
2018-10-03 09:43:13 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SubFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Folds subf when both operands are FloatAttr constants, producing a
// FloatAttr holding lhs - rhs. Returns nullptr when either operand is null
// or not a FloatAttr.
Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "subf takes two operands");

  auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0]);
  if (!lhs)
    return nullptr;

  auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]);
  if (!rhs)
    return nullptr;

  return FloatAttr::get(lhs->getValue() - rhs->getValue(), context);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SubIOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Folds subi when both operands are IntegerAttr constants, producing an
// IntegerAttr holding lhs - rhs. Returns nullptr when either operand is null
// or not an IntegerAttr.
Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands,
                                MLIRContext *context) const {
  assert(operands.size() == 2 && "subi takes two operands");

  auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
  if (!lhs)
    return nullptr;

  auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]);
  if (!rhs)
    return nullptr;

  // NOTE(review): overflow of the subtraction is not checked here.
  return IntegerAttr::get(lhs->getValue() - rhs->getValue(), context);
}
|