2020-09-28 22:47:00 -07:00
|
|
|
//===- Async.cpp - MLIR Async Operations ----------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "mlir/Dialect/Async/IR/Async.h"
|
|
|
|
|
|
|
|
#include "mlir/IR/DialectImplementation.h"
|
2022-11-02 11:27:26 -07:00
|
|
|
#include "mlir/IR/FunctionImplementation.h"
|
2023-01-08 14:15:07 -08:00
|
|
|
#include "mlir/IR/IRMapping.h"
|
2022-11-02 11:27:26 -07:00
|
|
|
#include "llvm/ADT/MapVector.h"
|
2020-09-28 22:47:00 -07:00
|
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
|
2020-10-08 13:28:09 -07:00
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::async;
|
2020-09-28 22:47:00 -07:00
|
|
|
|
2021-06-28 22:54:11 +00:00
|
|
|
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
|
|
|
|
|
2021-07-28 15:25:00 -07:00
|
|
|
// Out-of-line definition of the dialect's static constexpr data member.
// Since C++17 static constexpr data members are implicitly inline, so this
// redeclaration is redundant (and deprecated) there; it is only required for
// ODR-uses when compiling under earlier language standards.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
|
|
|
|
|
2020-09-28 22:47:00 -07:00
|
|
|
void AsyncDialect::initialize() {
  // Register every operation generated from the dialect's ODS op definitions;
  // the generated list expands into the template argument list below.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();

  // Register every type generated from the dialect's ODS typedefs.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// ExecuteOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-08-09 12:47:13 -07:00
|
|
|
// Name of the attribute recording the sizes of ExecuteOp's two variadic
// operand groups ({dependencies, body operands}). It is populated by
// ExecuteOp::build and ExecuteOp::parse and elided by ExecuteOp::print.
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
|
2020-10-08 13:28:09 -07:00
|
|
|
|
2023-08-30 09:22:34 +02:00
|
|
|
/// Returns the operands forwarded when control enters `point`.
///
/// The body region is the only successor region of `async.execute` (see
/// getSuccessorRegions), so `point` must designate it; the operands passed
/// into the body are the async value operands (`bodyOperands`).
OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  // The API now takes a RegionBranchPoint rather than a region index, so the
  // diagnostic refers to the branch point.
  assert(point == getBodyRegion() && "invalid region branch point");
  return getBodyOperands();
}
|
|
|
|
|
[mlir] Region/BranchOpInterface: Allow implicit type conversions along control-flow edges
RegionBranchOpInterface and BranchOpInterface are allowed to make implicit type conversions along control-flow edges. In effect, this adds an interface method, `areTypesCompatible`, to both interfaces, which should return whether the types of corresponding successor operands and block arguments are compatible. Users of the interfaces, here on forth, must be aware that types may mismatch, although current users (in MLIR core), are not affected by this change. By default, type equality is used.
`async.execute` already has unequal types along control-flow edges (`!async.value<f32>` vs. `f32`), but it opted out of calling `RegionBranchOpInterface::verifyTypes` in its verifier. That method has now been removed and `RegionBranchOpInterface` will verify types along control edges by default in its verifier.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D120790
2022-03-04 20:23:24 +00:00
|
|
|
/// Types along the execute op's control-flow edges are compatible when they
/// are equal after stripping one `!async.value<...>` wrapper from each side,
/// e.g. `!async.value<f32>` is compatible with `f32`.
bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
  // Peel off an async value wrapper; every other type is left unchanged.
  auto stripAsyncValue = [](Type candidate) -> Type {
    auto asyncValue = llvm::dyn_cast<ValueType>(candidate);
    return asyncValue ? asyncValue.getValueType() : candidate;
  };

  Type lhsPayload = stripAsyncValue(lhs);
  Type rhsPayload = stripAsyncValue(rhs);
  return lhsPayload == rhsPayload;
}
|
|
|
|
|
2023-08-30 09:22:34 +02:00
|
|
|
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
|
2020-11-11 01:38:51 -08:00
|
|
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
|
|
|
// The `body` region branch back to the parent operation.
|
2023-08-30 09:22:34 +02:00
|
|
|
if (point == getBodyRegion()) {
|
2022-09-29 19:00:10 -07:00
|
|
|
regions.push_back(RegionSuccessor(getBodyResults()));
|
2020-11-11 01:38:51 -08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise the successor is the body region.
|
2022-09-21 13:08:26 -07:00
|
|
|
regions.push_back(
|
2022-09-29 19:00:10 -07:00
|
|
|
RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
|
2020-11-11 01:38:51 -08:00
|
|
|
}
|
|
|
|
|
2020-11-13 03:01:52 -08:00
|
|
|
/// Builds an `async.execute` op.
///
/// Operands are the dependency tokens followed by `operands`; results are a
/// completion token followed by each of `resultTypes` wrapped in
/// `!async.value<...>`. The body block gets one argument per operand (the
/// unwrapped payload type for async values). When `bodyBuilder` is given it
/// populates the body; otherwise a `async.yield` terminator is created only if
/// there are no results to yield.
void ExecuteOp::build(OpBuilder &builder, OperationState &result,
                      TypeRange resultTypes, ValueRange dependencies,
                      ValueRange operands, BodyBuilderFn bodyBuilder) {
  // Operand groups: dependencies first, then body operands. Record the size
  // of each group in the derived `operandSegmentSizes` attribute.
  result.addOperands(dependencies);
  result.addOperands(operands);
  int32_t segmentSizes[] = {static_cast<int32_t>(dependencies.size()),
                            static_cast<int32_t>(operands.size())};
  result.addAttribute(kOperandSegmentSizesAttr,
                      builder.getDenseI32ArrayAttr(segmentSizes));

  // Results: the completion token, then every requested type wrapped into
  // an async value.
  result.addTypes({TokenType::get(result.getContext())});
  for (Type payload : resultTypes)
    result.addTypes(ValueType::get(payload));

  // Body: a single block whose arguments mirror the operands, with async
  // values unwrapped to their payload types.
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = new Block;
  bodyRegion->push_back(bodyBlock);
  for (Value operand : operands) {
    Type argType = operand.getType();
    if (auto asyncValue = llvm::dyn_cast<ValueType>(argType))
      argType = asyncValue.getValueType();
    bodyBlock->addArgument(argType, operand.getLoc());
  }

  // Without a body builder we can only synthesize a terminator when nothing
  // needs to be yielded; otherwise the caller must fill in the body because
  // we don't know which values to return from the execute op.
  if (!bodyBuilder && !resultTypes.empty())
    return;

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(bodyBlock);
  if (bodyBuilder)
    bodyBuilder(builder, result.location, bodyBlock->getArguments());
  else
    builder.create<async::YieldOp>(result.location, ValueRange());
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints `async.execute` in its custom form:
///   [%deps] (%value as %unwrapped: !async.value<T>, ...) -> (results)
///   attributes {...} { body }
void ExecuteOp::print(OpAsmPrinter &p) {
  // Dependency tokens: ` [%token, ...]`, omitted when there are none.
  if (!getDependencies().empty())
    p << " [" << getDependencies() << "]";

  // Body operands: ` (%value as %unwrapped: !async.value<T>, ...)`. The
  // unwrapped names are the entry block arguments, when a block exists.
  if (!getBodyOperands().empty()) {
    Block *entryBlock =
        getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
    unsigned argIndex = 0;
    p << " (";
    llvm::interleaveComma(getBodyOperands(), p, [&](Value operand) {
      Value unwrapped;
      if (entryBlock)
        unwrapped = entryBlock->getArgument(argIndex++);
      p << operand << " as " << unwrapped << ": " << operand.getType();
    });
    p << ")";
  }

  // Result types without the implicit leading completion token:
  // ` -> (!async.value<T>, ...)`.
  p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
                                     {kOperandSegmentSizesAttr});
  p << ' ';
  // Entry block arguments were already printed inline with the operands.
  p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom form produced by ExecuteOp::print:
///   [%deps] (%value as %unwrapped : !async.value<T>, ...) -> (results)
///   attributes {...} { body }
ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
  auto tokenTy = TokenType::get(result.getContext());

  // Optional dependency list `[%token, ...]`; every entry is a token.
  int32_t numDependencies = 0;
  if (succeeded(parser.parseOptionalLSquare())) {
    SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
    if (parser.parseOperandList(tokenArgs) ||
        parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
        parser.parseRSquare())
      return failure();
    numDependencies = tokenArgs.size();
  }

  // Optional body operand list `(%value as %unwrapped : !async.value<T>, ...)`.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
  SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
  SmallVector<Type, 4> valueTypes;

  // Parses one `%value as %unwrapped : !async.value<T>` entry. The region
  // argument receives the payload type `T` (a null type if the operand type
  // is not an async value).
  auto parseAsyncValueArg = [&]() -> ParseResult {
    OpAsmParser::UnresolvedOperand &operand = valueArgs.emplace_back();
    OpAsmParser::Argument &unwrapped = unwrappedArgs.emplace_back();
    Type &wrappedType = valueTypes.emplace_back();
    if (parser.parseOperand(operand) || parser.parseKeyword("as") ||
        parser.parseArgument(unwrapped) || parser.parseColonType(wrappedType))
      return failure();

    auto asyncValueTy = llvm::dyn_cast<ValueType>(wrappedType);
    unwrapped.type = asyncValueTy ? asyncValueTy.getValueType() : Type();
    return success();
  };

  auto argsLoc = parser.getCurrentLocation();
  if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
                                     parseAsyncValueArg) ||
      parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
    return failure();

  // Derived `operandSegmentSizes`: {dependencies, body operands}.
  int32_t segmentSizes[] = {numDependencies,
                            static_cast<int32_t>(valueArgs.size())};
  result.addAttribute(kOperandSegmentSizesAttr,
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));

  // Optional `-> (types)`. The completion token is always result #0 and is
  // never spelled out in the textual form.
  SmallVector<Type, 4> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.push_back(tokenTy);
  result.addTypes(resultTypes);

  // Optional `attributes {...}`.
  NamedAttrList attrs;
  if (parser.parseOptionalAttrDictWithKeyword(attrs))
    return failure();
  result.addAttributes(attrs);

  // The asynchronous body; its entry arguments are the unwrapped values.
  Region *body = result.addRegion();
  return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
}
|
|
|
|
|
2022-03-10 22:10:45 +00:00
|
|
|
/// Checks that each body region argument has the payload type of the
/// corresponding `!async.value<T>` body operand.
LogicalResult ExecuteOp::verifyRegions() {
  // NOTE(review): the cast assumes the ODS operand constraint has already
  // rejected non-`!async.value` body operands; confirm against AsyncOps.td.
  auto expectedArgTypes =
      llvm::map_range(getBodyOperands(), [](Value bodyOperand) -> Type {
        return llvm::cast<ValueType>(bodyOperand.getType()).getValueType();
      });

  if (getBodyRegion().getArgumentTypes() != expectedArgTypes)
    return emitOpError("async body region argument types do not match the "
                       "execute operation arguments types");
  return success();
}
|
2020-09-29 13:55:33 -07:00
|
|
|
|
2021-06-27 17:44:31 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// CreateGroupOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Erases a group that is only ever awaited (nothing is added to it),
/// together with all of its `async.await_all` users.
LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
                                          PatternRewriter &rewriter) {
  // Gather the group's users; any user that is not an `await_all` blocks the
  // rewrite.
  llvm::SmallVector<AwaitAllOp> awaitAllUsers;
  for (Operation *user : op->getUsers()) {
    auto awaitAll = dyn_cast<AwaitAllOp>(user);
    if (!awaitAll)
      return failure();
    awaitAllUsers.push_back(awaitAll);
  }

  // The group is only awaited without adding anything to it: drop the
  // awaits first (they use the group value), then the group itself.
  for (AwaitAllOp awaitAll : awaitAllUsers)
    rewriter.eraseOp(awaitAll);
  rewriter.eraseOp(op);
  return success();
}
|
|
|
|
|
2020-10-12 14:38:42 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// AwaitOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builds `async.await` on `operand`. Awaiting a `!async.value<T>` produces a
/// single result of type `T`; awaiting anything else produces no result.
void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
                    ArrayRef<NamedAttribute> attrs) {
  result.addOperands({operand});
  result.addAttributes(attrs);

  auto asyncValueTy = llvm::dyn_cast<ValueType>(operand.getType());
  if (!asyncValueTy)
    return;
  // The result is the unwrapped payload of the awaited async value.
  result.addTypes(asyncValueTy.getValueType());
}
|
|
|
|
|
|
|
|
/// Custom directive for `async.await`: parses only the operand type and,
/// when it is `!async.value<T>`, derives the result type `T` from it.
/// For any other operand type `resultType` is left untouched.
static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
                                        Type &resultType) {
  if (failed(parser.parseType(operandType)))
    return failure();

  if (auto asyncValueTy = llvm::dyn_cast<ValueType>(operandType))
    resultType = asyncValueTy.getValueType();
  return success();
}
|
|
|
|
|
2020-10-28 01:01:44 +00:00
|
|
|
/// Counterpart of parseAwaitResultType: prints only the operand type, since
/// the result type is implied by it. `op` and `resultType` are unused.
static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
                                 Type operandType, Type resultType) {
  p.printType(operandType);
}
|
|
|
|
|
2022-02-02 10:24:48 -08:00
|
|
|
/// Verifies `async.await`:
///  - awaiting a token must produce no results;
///  - awaiting `!async.value<T>` must produce exactly one result of type `T`.
LogicalResult AwaitOp::verify() {
  Type argType = getOperand().getType();

  // Awaiting on a token does not have any results.
  if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
    return emitOpError("awaiting on a token must have empty result");

  // Awaiting on a value unwraps the async value type.
  if (auto value = llvm::dyn_cast<ValueType>(argType)) {
    // The op may structurally have zero results (see the token case above),
    // so guard before dereferencing; previously `*getResultType()` was
    // evaluated unconditionally, which is undefined behavior when no result
    // exists.
    auto resultType = getResultType();
    if (!resultType)
      return emitOpError() << "awaiting on an async value must have a result "
                              "of type "
                           << value.getValueType();
    if (*resultType != value.getValueType())
      return emitOpError() << "result type " << *resultType
                           << " does not match async value type "
                           << value.getValueType();
  }

  return success();
}
|
|
|
|
|
2022-11-02 11:27:26 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// FuncOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builds an `async.func` named `name` with signature `type`. Extra
/// attributes from `attrs` are appended, the body region is left empty, and
/// the optional `argAttrs` (one dictionary per input) are attached.
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  auto symbolName = builder.getStringAttr(name);
  state.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
  state.addAttributes(attrs);
  state.addRegion();

  // Argument attributes are optional; when given there must be exactly one
  // dictionary per function input.
  if (!argAttrs.empty()) {
    assert(type.getNumInputs() == argAttrs.size());
    function_interface_impl::addArgAndResultAttrs(
        builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
        getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
  }
}
|
|
|
|
|
|
|
|
/// Parses `async.func` using the shared function-op parser; variadic
/// signatures are not allowed.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Builds a plain FunctionType; the variadic flag and the error-message
  // out-parameter are intentionally ignored.
  auto makeFuncType = [](Builder &builder, ArrayRef<Type> inputs,
                         ArrayRef<Type> outputs,
                         function_interface_impl::VariadicFlag,
                         std::string &) {
    return builder.getFunctionType(inputs, outputs);
  };

  auto typeAttrName = getFunctionTypeAttrName(result.name);
  auto argAttrsName = getArgAttrsAttrName(result.name);
  auto resAttrsName = getResAttrsAttrName(result.name);
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, typeAttrName, makeFuncType,
      argAttrsName, resAttrsName);
}
|
|
|
|
|
|
|
|
/// Prints `async.func` via the shared function-op printer; the op is never
/// variadic.
void FuncOp::print(OpAsmPrinter &p) {
  auto typeAttrName = getFunctionTypeAttrName();
  auto argAttrsName = getArgAttrsAttrName();
  auto resAttrsName = getResAttrsAttrName();
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
                                           typeAttrName, argAttrsName,
                                           resAttrsName);
}
|
|
|
|
|
|
|
|
/// Check that the result type of async.func is not void and must be
|
|
|
|
/// some async token or async values.
|
|
|
|
LogicalResult FuncOp::verify() {
|
|
|
|
auto resultTypes = getResultTypes();
|
|
|
|
if (resultTypes.empty())
|
|
|
|
return emitOpError()
|
|
|
|
<< "result is expected to be at least of size 1, but got "
|
|
|
|
<< resultTypes.size();
|
|
|
|
|
|
|
|
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
|
|
|
|
auto type = resultTypes[i];
|
2023-05-11 11:10:46 +02:00
|
|
|
if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
|
2022-11-02 11:27:26 -07:00
|
|
|
return emitOpError() << "result type must be async value type or async "
|
|
|
|
"token type, but got "
|
|
|
|
<< type;
|
|
|
|
// We only allow AsyncToken appear as the first return value
|
2023-05-11 11:10:46 +02:00
|
|
|
if (llvm::isa<TokenType>(type) && i != 0) {
|
2022-11-02 11:27:26 -07:00
|
|
|
return emitOpError()
|
|
|
|
<< " results' (optional) async token type is expected "
|
|
|
|
"to appear as the 1st return value, but got "
|
|
|
|
<< i + 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// CallOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies that the `callee` symbol resolves to an `async.func` whose
/// signature matches this call's operand and result types exactly.
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto calleeAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!calleeAttr)
    return emitOpError("requires a 'callee' symbol reference attribute");

  FuncOp callee =
      symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, calleeAttr);
  if (!callee)
    return emitOpError() << "'" << calleeAttr.getValue()
                         << "' does not reference a valid async function";

  auto calleeType = callee.getFunctionType();

  // Operands: matching count, then matching type at every position.
  unsigned numInputs = calleeType.getNumInputs();
  if (numInputs != getNumOperands())
    return emitOpError("incorrect number of operands for callee");
  for (unsigned idx = 0; idx < numInputs; ++idx) {
    Type expected = calleeType.getInput(idx);
    Type provided = getOperand(idx).getType();
    if (provided != expected)
      return emitOpError("operand type mismatch: expected operand type ")
             << expected << ", but provided " << provided
             << " for operand number " << idx;
  }

  // Results: matching count, then matching type at every position.
  unsigned numOutputs = calleeType.getNumResults();
  if (numOutputs != getNumResults())
    return emitOpError("incorrect number of results for callee");
  for (unsigned idx = 0; idx < numOutputs; ++idx) {
    if (getResult(idx).getType() == calleeType.getResult(idx))
      continue;
    auto diag = emitOpError("result type mismatch at index ") << idx;
    diag.attachNote() << " op result types: " << getResultTypes();
    diag.attachNote() << "function result types: " << calleeType.getResults();
    return diag;
  }

  return success();
}
|
|
|
|
|
|
|
|
/// Returns the callee signature implied by this call site: its operand types
/// as inputs and its result types as outputs.
FunctionType CallOp::getCalleeType() {
  auto inputTypes = getOperandTypes();
  auto outputTypes = getResultTypes();
  return FunctionType::get(getContext(), inputTypes, outputTypes);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// ReturnOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies that the returned operands have the unwrapped payload types of
/// the parent `async.func` results. When the function is stateful, its first
/// result is skipped.
LogicalResult ReturnOp::verify() {
  auto funcOp = (*this)->getParentOfType<FuncOp>();

  ArrayRef<Type> asyncResultTypes = funcOp.getResultTypes();
  if (funcOp.isStateful())
    asyncResultTypes = asyncResultTypes.drop_front();

  // Each remaining result is `!async.value<T>`; the operand must be a `T`.
  auto expectedTypes =
      llvm::map_range(asyncResultTypes, [](Type asyncType) -> Type {
        return llvm::cast<ValueType>(asyncType).getValueType();
      });

  if (getOperandTypes() != expectedTypes)
    return emitOpError("operand types do not match the types returned from "
                       "the parent FuncOp");
  return success();
}
|
|
|
|
|
2021-01-25 14:14:12 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TableGen'd op method definitions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-09-28 22:47:00 -07:00
|
|
|
#define GET_OP_CLASSES
|
|
|
|
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
|
2021-01-25 14:14:12 -08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TableGen'd type method definitions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#define GET_TYPEDEF_CLASSES
|
|
|
|
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
|
|
|
|
|
2021-11-11 06:12:06 +00:00
|
|
|
/// Prints the parameter list of `!async.value`, i.e. `<payload-type>`.
void ValueType::print(AsmPrinter &printer) const {
  printer << "<" << getValueType() << '>';
}
|
|
|
|
|
2021-11-11 06:12:06 +00:00
|
|
|
/// Parses the parameter list of `!async.value`, i.e. `<payload-type>`.
/// On failure an additional error is emitted and a null type is returned.
Type ValueType::parse(mlir::AsmParser &parser) {
  Type payloadType;
  const bool parsed = succeeded(parser.parseLess()) &&
                      succeeded(parser.parseType(payloadType)) &&
                      succeeded(parser.parseGreater());
  if (parsed)
    return ValueType::get(payloadType);

  parser.emitError(parser.getNameLoc(), "failed to parse async value type");
  return Type();
}
|