2018-07-05 09:12:11 -07:00
|
|
|
//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/IR/StandardOps.h"
|
2018-07-24 10:13:31 -07:00
|
|
|
#include "mlir/IR/AffineMap.h"
|
2018-07-25 11:15:20 -07:00
|
|
|
#include "mlir/IR/Builders.h"
|
2018-07-24 16:07:22 -07:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2018-07-05 09:12:11 -07:00
|
|
|
#include "mlir/IR/OperationSet.h"
|
2018-07-24 08:34:58 -07:00
|
|
|
#include "mlir/IR/SSAValue.h"
|
|
|
|
#include "mlir/IR/Types.h"
|
2018-08-09 12:28:58 -07:00
|
|
|
#include "mlir/Support/STLExtras.h"
|
2018-07-05 09:12:11 -07:00
|
|
|
#include "llvm/Support/raw_ostream.h"
|
2018-08-09 12:28:58 -07:00
|
|
|
|
2018-07-05 09:12:11 -07:00
|
|
|
using namespace mlir;
|
|
|
|
|
2018-07-30 13:08:05 -07:00
|
|
|
/// Prints the operand range [begin, end) as two groups: the first 'numDims'
/// operands inside parentheses, then the remaining operands inside square
/// brackets.  The bracketed group is omitted when no operands remain.
static void printDimAndSymbolList(Operation::const_operand_iterator begin,
                                  Operation::const_operand_iterator end,
                                  unsigned numDims, OpAsmPrinter *p) {
  // Boundary between the parenthesized group and the bracketed group.
  auto trailingBegin = begin + numDims;

  *p << '(';
  p->printOperands(begin, trailingBegin);
  *p << ')';

  if (trailingBegin != end) {
    *p << '[';
    p->printOperands(trailingBegin, end);
    *p << ']';
  }
}
|
|
|
|
|
|
|
|
// Parses dimension and symbol list, and sets 'numDims' to the number of
|
|
|
|
// dimension operands parsed.
|
|
|
|
// Returns 'false' on success and 'true' on error.
|
|
|
|
static bool
|
|
|
|
parseDimAndSymbolList(OpAsmParser *parser,
|
|
|
|
SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
|
2018-08-07 09:12:35 -07:00
|
|
|
SmallVector<OpAsmParser::OperandType, 8> opInfos;
|
2018-08-02 16:54:36 -07:00
|
|
|
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
|
2018-07-30 13:08:05 -07:00
|
|
|
return true;
|
|
|
|
// Store number of dimensions for validation by caller.
|
|
|
|
numDims = opInfos.size();
|
|
|
|
|
|
|
|
// Parse the optional symbol operands.
|
|
|
|
auto *affineIntTy = parser->getBuilder().getAffineIntType();
|
2018-08-02 16:54:36 -07:00
|
|
|
if (parser->parseOperandList(opInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::OptionalSquare) ||
|
2018-07-30 13:08:05 -07:00
|
|
|
parser->resolveOperands(opInfos, affineIntTy, operands))
|
|
|
|
return true;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AddFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses two operands, an optional attribute dictionary, and a
/// colon-prefixed type.  Both operands are resolved against that type, which
/// is also added to the result type list.  Returns 'true' on error.
bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
  Type *elementType;

  if (parser->parseOperandList(operandInfos, 2))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(elementType))
    return true;
  if (parser->resolveOperands(operandInfos, elementType, result->operands))
    return true;
  if (parser->addTypeToList(elementType, result->types))
    return true;
  return false;
}
|
|
|
|
|
2018-07-24 16:07:22 -07:00
|
|
|
/// Prints "addf <lhs>, <rhs>", any attributes, and then " : <type>".
void AddFOp::print(OpAsmPrinter *p) const {
  *p << "addf ";
  *p << *getOperand(0);
  *p << ", ";
  *p << *getOperand(1);

  p->printOptionalAttrDict(getAttrs());

  *p << " : ";
  *p << *getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
// TODO: Have verify functions return std::string to enable more descriptive
|
|
|
|
// error messages.
|
2018-07-06 10:46:19 -07:00
|
|
|
// Return an error message on failure.
|
|
|
|
const char *AddFOp::verify() const {
|
|
|
|
// TODO: Check that the types of the LHS and RHS match.
|
|
|
|
// TODO: This should be a refinement of TwoOperands.
|
|
|
|
// TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AffineApplyOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses an affine map attribute (stored under "map"), the dimension and
/// symbol operand lists, and an optional attribute dictionary.  The operand
/// counts are checked against the map, and one 'affineint' result type is
/// added per map result.  Returns 'true' on error.
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
  AffineMapAttr *mapAttr;
  unsigned numDims;

  if (parser->parseAttribute(mapAttr, "map", result->attributes))
    return true;
  if (parseDimAndSymbolList(parser, result->operands, numDims))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;

  auto *map = mapAttr->getValue();

  // The parenthesized operands must match the map's dimensions, and the
  // operands overall must account for exactly the dimensions plus symbols.
  const bool dimCountMatches = map->getNumDims() == numDims;
  const bool operandCountMatches =
      numDims + map->getNumSymbols() == result->operands.size();
  if (!dimCountMatches || !operandCountMatches)
    return parser->emitError(parser->getNameLoc(),
                             "dimension or symbol index mismatch");

  auto *affineIntTy = parser->getBuilder().getAffineIntType();
  result->types.append(map->getNumResults(), affineIntTy);
  return false;
}
|
|
|
|
|
|
|
|
/// Prints "affine_apply <map>", then the dimension and symbol operand lists,
/// then any attributes other than "map" (which was already printed inline).
void AffineApplyOp::print(OpAsmPrinter *p) const {
  auto *affineMap = getAffineMap();

  *p << "affine_apply ";
  *p << *affineMap;

  printDimAndSymbolList(operand_begin(), operand_end(),
                        affineMap->getNumDims(), p);
  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
|
|
|
|
|
|
|
|
const char *AffineApplyOp::verify() const {
|
|
|
|
// Check that affine map attribute was specified.
|
|
|
|
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
|
|
|
if (!affineMapAttr)
|
|
|
|
return "requires an affine map.";
|
|
|
|
|
|
|
|
// Check input and output dimensions match.
|
|
|
|
auto *map = affineMapAttr->getValue();
|
|
|
|
|
|
|
|
// Verify that operand count matches affine map dimension and symbol count.
|
|
|
|
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
|
|
|
|
return "operand count and affine map dimension and symbol count must match";
|
|
|
|
|
|
|
|
// Verify that result count matches affine map result count.
|
|
|
|
if (getNumResults() != map->getNumResults())
|
|
|
|
return "result count and affine map result count must match";
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AllocOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-07-30 13:08:05 -07:00
|
|
|
/// Prints "alloc", the dynamic-dimension operands in parentheses, any
/// remaining (symbol) operands in square brackets, the attribute dictionary,
/// and finally " : <memref type>".
void AllocOp::print(OpAsmPrinter *p) const {
  MemRefType *type = cast<MemRefType>(getMemRef()->getType());
  *p << "alloc";
  // Print dynamic dimension operands.
  printDimAndSymbolList(operand_begin(), operand_end(),
                        type->getNumDynamicDims(), p);
  // Unlike affine_apply, alloc prints no attribute inline, so nothing may be
  // elided here: AllocOp::parse accepts an arbitrary attribute dictionary and
  // every attribute must be printed for the op to round-trip.
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << *type;
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses the dimension operands and optional symbol operands, an optional
/// attribute dictionary, and a colon-prefixed memref type.  The operand
/// counts are validated against the memref type, which becomes the single
/// result type.  Returns 'true' on error.
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
  MemRefType *memrefType;
  unsigned numDimOperands;

  if (parseDimAndSymbolList(parser, result->operands, numDimOperands))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(memrefType))
    return true;

  // One dimension operand is required per dynamic dimension (question mark)
  // in the memref type.
  if (memrefType->getNumDynamicDims() != numDimOperands)
    return parser->emitError(parser->getNameLoc(),
                             "dimension operand count does not equal memref "
                             "dynamic dimension count");

  // The remaining operands are symbols; their count must match the number of
  // symbols in the first affine map of the memref's affine map composition.
  // Note that a memref must specify at least one affine map in the
  // composition.
  const auto numSymbolOperands = result->operands.size() - numDimOperands;
  if (numSymbolOperands != memrefType->getAffineMaps()[0]->getNumSymbols())
    return parser->emitError(
        parser->getNameLoc(),
        "affine map symbol operand count does not equal memref affine map "
        "symbol count");

  result->types.push_back(memrefType);
  return false;
}
|
|
|
|
|
|
|
|
/// Verifies this op.  Returns an error message on failure, or nullptr on
/// success.  No checks are implemented yet, so this always succeeds.
const char *AllocOp::verify() const {
  // TODO(andydavis): Implement verification for alloc.
  return nullptr;
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-01 10:43:18 -07:00
|
|
|
/// Prints "constant <value>", any attributes other than "value" (already
/// printed inline), and then " : <type>".
void ConstantOp::print(OpAsmPrinter *p) const {
  *p << "constant ";
  *p << *getValue();

  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");

  *p << " : ";
  *p << *getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses an attribute (stored under "value"), an optional attribute
/// dictionary, and a colon-prefixed type that is added to the result type
/// list.  Returns 'true' on error.
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
  Attribute *valueAttr;
  Type *resultType;

  if (parser->parseAttribute(valueAttr, "value", result->attributes))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(resultType))
    return true;
  if (parser->addTypeToList(resultType, result->types))
    return true;
  return false;
}
|
|
|
|
|
2018-07-24 08:34:58 -07:00
|
|
|
/// The constant op requires an attribute, and furthermore requires that it
|
|
|
|
/// matches the return type.
|
|
|
|
const char *ConstantOp::verify() const {
|
|
|
|
auto *value = getValue();
|
|
|
|
if (!value)
|
|
|
|
return "requires a 'value' attribute";
|
|
|
|
|
|
|
|
auto *type = this->getType();
|
2018-07-24 10:41:30 -07:00
|
|
|
if (isa<IntegerType>(type) || type->isAffineInt()) {
|
2018-07-24 08:34:58 -07:00
|
|
|
if (!isa<IntegerAttr>(value))
|
|
|
|
return "requires 'value' to be an integer for an integer result type";
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (isa<FunctionType>(type)) {
|
|
|
|
// TODO: Verify a function attr.
|
|
|
|
}
|
|
|
|
|
|
|
|
return "requires a result type that aligns with the 'value' attribute";
|
|
|
|
}
|
|
|
|
|
2018-08-07 12:02:37 -07:00
|
|
|
/// ConstantIntOp only matches values whose result type is an IntegerType.
|
2018-07-24 08:34:58 -07:00
|
|
|
bool ConstantIntOp::isClassFor(const Operation *op) {
|
|
|
|
return ConstantOp::isClassFor(op) &&
|
2018-08-07 12:02:37 -07:00
|
|
|
isa<IntegerType>(op->getResult(0)->getType());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Builds the state for a "constant" operation whose "value" attribute is the
/// integer 'value' and whose result is an integer type of the given 'width'.
OperationState ConstantIntOp::build(Builder *builder, int64_t value,
                                    unsigned width) {
  OperationState state(builder->getIdentifier("constant"));

  auto valueName = builder->getIdentifier("value");
  auto *valueAttr = builder->getIntegerAttr(value);
  state.attributes.push_back({valueName, valueAttr});

  state.types.push_back(builder->getIntegerType(width));
  return state;
}
|
|
|
|
|
|
|
|
/// ConstantAffineIntOp only matches values whose result type is AffineInt.
|
|
|
|
bool ConstantAffineIntOp::isClassFor(const Operation *op) {
|
|
|
|
return ConstantOp::isClassFor(op) &&
|
|
|
|
op->getResult(0)->getType()->isAffineInt();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Builds the state for a "constant" operation whose "value" attribute is the
/// integer 'value' and whose result type is AffineInt.
OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
  OperationState state(builder->getIdentifier("constant"));

  auto valueName = builder->getIdentifier("value");
  auto *valueAttr = builder->getIntegerAttr(value);
  state.attributes.push_back({valueName, valueAttr});

  state.types.push_back(builder->getAffineIntType());
  return state;
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DimOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-07-24 16:07:22 -07:00
|
|
|
/// Prints "dim <operand>, <index>", any attributes other than "index"
/// (already printed inline), and then " : <operand type>".
void DimOp::print(OpAsmPrinter *p) const {
  *p << "dim ";
  *p << *getOperand();
  *p << ", ";
  *p << getIndex();

  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");

  *p << " : ";
  *p << *getOperand()->getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses an operand, a comma, an integer attribute (stored under "index"),
/// an optional attribute dictionary, and a colon-prefixed type against which
/// the operand is resolved.  The result type is always 'affineint'.
/// Returns 'true' on error.
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType operandInfo;
  IntegerAttr *indexAttr;
  Type *operandType;

  if (parser->parseOperand(operandInfo) || parser->parseComma())
    return true;
  if (parser->parseAttribute(indexAttr, "index", result->attributes))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(operandType))
    return true;
  if (parser->resolveOperand(operandInfo, operandType, result->operands))
    return true;
  if (parser->addTypeToList(parser->getBuilder().getAffineIntType(),
                            result->types))
    return true;
  return false;
}
|
|
|
|
|
2018-07-06 10:46:19 -07:00
|
|
|
const char *DimOp::verify() const {
|
|
|
|
// Check that we have an integer index operand.
|
|
|
|
auto indexAttr = getAttrOfType<IntegerAttr>("index");
|
|
|
|
if (!indexAttr)
|
2018-07-24 08:34:58 -07:00
|
|
|
return "requires an integer attribute named 'index'";
|
|
|
|
uint64_t index = (uint64_t)indexAttr->getValue();
|
|
|
|
|
|
|
|
auto *type = getOperand()->getType();
|
|
|
|
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
|
|
|
|
if (index >= tensorType->getRank())
|
|
|
|
return "index is out of range";
|
|
|
|
} else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
|
|
|
|
if (index >= memrefType->getRank())
|
|
|
|
return "index is out of range";
|
2018-07-06 10:46:19 -07:00
|
|
|
|
2018-07-24 08:34:58 -07:00
|
|
|
} else if (isa<UnrankedTensorType>(type)) {
|
|
|
|
// ok, assumed to be in-range.
|
|
|
|
} else {
|
|
|
|
return "requires an operand with tensor or memref type";
|
|
|
|
}
|
2018-07-06 10:46:19 -07:00
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// LoadOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-07-25 11:15:20 -07:00
|
|
|
/// Prints "load <memref>[<indices>]", any attributes, and then
/// " : <memref type>".
void LoadOp::print(OpAsmPrinter *p) const {
  *p << "load ";
  *p << *getMemRef();

  *p << '[';
  p->printOperands(getIndices());
  *p << ']';

  p->printOptionalAttrDict(getAttrs());

  *p << " : ";
  *p << *getMemRef()->getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses a memref operand, a square-bracketed index list, an optional
/// attribute dictionary, and a colon-prefixed memref type.  The memref is
/// resolved against that type, the indices against 'affineint', and the
/// result type is the memref's element type.  Returns 'true' on error.
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfos;
  MemRefType *memrefType;

  auto affineIntTy = parser->getBuilder().getAffineIntType();

  if (parser->parseOperand(memrefInfo))
    return true;
  if (parser->parseOperandList(indexInfos, -1,
                               OpAsmParser::Delimiter::Square))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(memrefType))
    return true;

  // 'memrefType' is only valid from here on, once the type has been parsed.
  if (parser->resolveOperand(memrefInfo, memrefType, result->operands))
    return true;
  if (parser->resolveOperands(indexInfos, affineIntTy, result->operands))
    return true;
  if (parser->addTypeToList(memrefType->getElementType(), result->types))
    return true;
  return false;
}
|
|
|
|
|
|
|
|
const char *LoadOp::verify() const {
|
2018-07-28 09:36:25 -07:00
|
|
|
if (getNumOperands() == 0)
|
|
|
|
return "expected a memref to load from";
|
2018-07-24 10:13:31 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
|
|
|
if (!memRefType)
|
|
|
|
return "first operand must be a memref";
|
2018-07-24 10:13:31 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
for (auto *idx : getIndices())
|
|
|
|
if (!idx->getType()->isAffineInt())
|
|
|
|
return "index to load must have 'affineint' type";
|
2018-07-24 17:43:56 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
// TODO: Verify we have the right number of indices.
|
2018-07-24 17:43:56 -07:00
|
|
|
|
2018-07-28 09:36:25 -07:00
|
|
|
// TODO: in MLFunction verify that the indices are parameters, IV's, or the
|
|
|
|
// result of an affine_apply.
|
2018-07-24 10:13:31 -07:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ReturnOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parses an undelimited operand list and, when at least one operand was
/// parsed, a colon-prefixed type list against which the operands are
/// resolved.  Returns 'true' on error.
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
  SmallVector<Type *, 2> operandTypes;

  if (parser->parseOperandList(operandInfos, -1,
                               OpAsmParser::Delimiter::None))
    return true;

  // The type list is only present when there is something to return.
  if (!operandInfos.empty() && parser->parseColonTypeList(operandTypes))
    return true;

  if (parser->resolveOperands(operandInfos, operandTypes, result->operands))
    return true;
  return false;
}
|
|
|
|
|
|
|
|
/// Prints "return", followed (only when there are operands) by the operand
/// list, " : ", and the comma-separated operand types.
void ReturnOp::print(OpAsmPrinter *p) const {
  *p << "return";
  if (getNumOperands() == 0)
    return;

  *p << " ";
  p->printOperands(operand_begin(), operand_end());
  *p << " : ";

  auto printOperandType = [&](auto *operand) {
    p->printType(operand->getType());
  };
  auto printSeparator = [&]() { *p << ", "; };
  interleave(operand_begin(), operand_end(), printOperandType, printSeparator);
}
|
|
|
|
|
|
|
|
const char *ReturnOp::verify() const {
|
|
|
|
// ReturnOp must be part of an ML function.
|
|
|
|
if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
|
2018-08-09 23:21:19 -07:00
|
|
|
MLFunction *func = dyn_cast_or_null<MLFunction>(stmt->getBlock());
|
|
|
|
if (!func || &func->back() != stmt)
|
2018-08-09 12:28:58 -07:00
|
|
|
return "must be the last statement in the ML function";
|
|
|
|
|
|
|
|
// Return success. Checking that operand types match those in the function
|
|
|
|
// signature is performed in the ML function verifier.
|
|
|
|
return nullptr;
|
|
|
|
}
|
2018-08-09 23:21:19 -07:00
|
|
|
return "cannot occur in a CFG function";
|
2018-08-09 12:28:58 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// StoreOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-07-31 14:11:38 -07:00
|
|
|
/// Prints "store <value>, <memref>[<indices>]", any attributes, and then
/// " : <memref type>".
void StoreOp::print(OpAsmPrinter *p) const {
  *p << "store ";
  *p << *getValueToStore();
  *p << ", ";
  *p << *getMemRef();

  *p << '[';
  p->printOperands(getIndices());
  *p << ']';

  p->printOptionalAttrDict(getAttrs());

  *p << " : ";
  *p << *getMemRef()->getType();
}
|
|
|
|
|
2018-08-07 09:12:35 -07:00
|
|
|
/// Parses the value to store, a comma, the memref operand, a square-bracketed
/// index list, an optional attribute dictionary, and a colon-prefixed memref
/// type.  The stored value is resolved against the memref's element type, the
/// memref against the parsed type, and the indices against 'affineint'.
/// Returns 'true' on error.
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType valueInfo;
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfos;
  MemRefType *memrefType;

  auto affineIntTy = parser->getBuilder().getAffineIntType();

  if (parser->parseOperand(valueInfo) || parser->parseComma())
    return true;
  if (parser->parseOperand(memrefInfo))
    return true;
  if (parser->parseOperandList(indexInfos, -1,
                               OpAsmParser::Delimiter::Square))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(memrefType))
    return true;

  // 'memrefType' is only valid from here on, once the type has been parsed.
  // Operands are resolved in the order: value, memref, indices.
  if (parser->resolveOperand(valueInfo, memrefType->getElementType(),
                             result->operands))
    return true;
  if (parser->resolveOperand(memrefInfo, memrefType, result->operands))
    return true;
  if (parser->resolveOperands(indexInfos, affineIntTy, result->operands))
    return true;
  return false;
}
|
|
|
|
|
|
|
|
const char *StoreOp::verify() const {
|
|
|
|
if (getNumOperands() < 2)
|
|
|
|
return "expected a value to store and a memref";
|
|
|
|
|
|
|
|
// Second operand is a memref type.
|
|
|
|
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
|
|
|
if (!memRefType)
|
|
|
|
return "second operand must be a memref";
|
|
|
|
|
|
|
|
// First operand must have same type as memref element type.
|
|
|
|
if (getValueToStore()->getType() != memRefType->getElementType())
|
|
|
|
return "first operand must have same type memref element type ";
|
|
|
|
|
|
|
|
if (getNumOperands() != 2 + memRefType->getRank())
|
|
|
|
return "store index operand count not equal to memref rank";
|
|
|
|
|
|
|
|
for (auto *idx : getIndices())
|
|
|
|
if (!idx->getType()->isAffineInt())
|
|
|
|
return "index to load must have 'affineint' type";
|
|
|
|
|
|
|
|
// TODO: Verify we have the right number of indices.
|
|
|
|
|
|
|
|
// TODO: in MLFunction verify that the indices are parameters, IV's, or the
|
|
|
|
// result of an affine_apply.
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2018-08-09 12:28:58 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Register operations.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-07-05 09:12:11 -07:00
|
|
|
/// Install the standard operations in the specified operation set.
|
|
|
|
void mlir::registerStandardOperations(OperationSet &opSet) {
|
2018-07-31 14:11:38 -07:00
|
|
|
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
|
2018-08-09 23:21:19 -07:00
|
|
|
ReturnOp, StoreOp>(
|
2018-07-31 14:11:38 -07:00
|
|
|
/*prefix=*/"");
|
2018-07-05 09:12:11 -07:00
|
|
|
}
|