mirror of
https://github.com/llvm/llvm-project.git
synced 2025-05-02 02:46:07 +00:00

Depends On D89958 1. Adds `async.group`/`async.await_all` to group together multiple async tokens/values 2. Rewrites scf.parallel operation into multiple concurrent async.execute operations over non-overlapping subranges of the original loop. Example: ``` scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) { "do_some_compute"(%i, %j): () -> () } ``` Converted to: ``` %c0 = constant 0 : index %c1 = constant 1 : index // Compute block sizes for each induction variable. %num_blocks_i = ... : index %num_blocks_j = ... : index %block_size_i = ... : index %block_size_j = ... : index // Create an async group to track async execute ops. %group = async.create_group scf.for %bi = %c0 to %num_blocks_i step %c1 { %block_start_i = ... : index %block_end_i = ... : index scf.for %bj = %c0 to %num_blocks_j step %c1 { %block_start_j = ... : index %block_end_j = ... : index // Execute the body of original parallel operation for the current // block. %token = async.execute { scf.for %i = %block_start_i to %block_end_i step %si { scf.for %j = %block_start_j to %block_end_j step %sj { "do_some_compute"(%i, %j): () -> () } } } // Add produced async token to the group. async.add_to_group %token, %group } } // Await completion of all async.execute operations. async.await_all %group ``` In this example the outer loop launches inner block-level loops as separate async execute operations which will be executed concurrently. At the end it waits for the completion of all async execute operations. Reviewed By: ftynse, mehdi_amini Differential Revision: https://reviews.llvm.org/D89963
367 lines
12 KiB
C++
367 lines
12 KiB
C++
//===- Async.cpp - MLIR Async Operations ----------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Async/IR/Async.h"
|
|
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::async;
|
|
|
|
/// Register all operations and types that make up the `async` dialect.
void AsyncDialect::initialize() {
  // Operations are generated from the ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();

  // Types are registered in declaration order: token, value<T>, group.
  addTypes<TokenType, ValueType, GroupType>();
}
|
|
|
|
/// Parse a type registered to this dialect.
|
|
Type AsyncDialect::parseType(DialectAsmParser &parser) const {
|
|
StringRef keyword;
|
|
if (parser.parseKeyword(&keyword))
|
|
return Type();
|
|
|
|
if (keyword == "token")
|
|
return TokenType::get(getContext());
|
|
|
|
if (keyword == "value") {
|
|
Type ty;
|
|
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
|
|
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
|
|
return Type();
|
|
}
|
|
return ValueType::get(ty);
|
|
}
|
|
|
|
parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
|
|
return Type();
|
|
}
|
|
|
|
/// Print a type registered to this dialect.
|
|
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|
TypeSwitch<Type>(type)
|
|
.Case<TokenType>([&](TokenType) { os << "token"; })
|
|
.Case<ValueType>([&](ValueType valueTy) {
|
|
os << "value<";
|
|
os.printType(valueTy.getValueType());
|
|
os << '>';
|
|
})
|
|
.Case<GroupType>([&](GroupType) { os << "group"; })
|
|
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// ValueType
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
namespace async {
|
|
namespace detail {
|
|
|
|
// Storage for `async.value<T>` type, the only member is the wrapped type.
|
|
struct ValueTypeStorage : public TypeStorage {
|
|
ValueTypeStorage(Type valueType) : valueType(valueType) {}
|
|
|
|
/// The hash key used for uniquing.
|
|
using KeyTy = Type;
|
|
bool operator==(const KeyTy &key) const { return key == valueType; }
|
|
|
|
/// Construction.
|
|
static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
|
|
Type valueType) {
|
|
return new (allocator.allocate<ValueTypeStorage>())
|
|
ValueTypeStorage(valueType);
|
|
}
|
|
|
|
Type valueType;
|
|
};
|
|
|
|
} // namespace detail
|
|
} // namespace async
|
|
} // namespace mlir
|
|
|
|
ValueType ValueType::get(Type valueType) {
|
|
return Base::get(valueType.getContext(), valueType);
|
|
}
|
|
|
|
Type ValueType::getValueType() { return getImpl()->valueType; }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// YieldOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verify that the operands of `async.yield` match the payload types of the
/// `async.value` results of the enclosing `async.execute`.
static LogicalResult verify(YieldOp op) {
  auto executeParent = op.getParentOfType<ExecuteOp>();

  // Strip the `async.value<...>` wrapper from each parent result type.
  auto expectedTypes =
      llvm::map_range(executeParent.results(), [](const OpResult &asyncResult) {
        return asyncResult.getType().cast<ValueType>().getValueType();
      });

  bool typesMismatch = op.getOperandTypes() != expectedTypes;
  if (typesMismatch)
    return op.emitOpError("operand types do not match the types returned from "
                          "the parent ExecuteOp");

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// ExecuteOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Name of the attribute holding the sizes of ExecuteOp's two operand
/// segments (dependency tokens, then async value operands). The custom
/// printer elides it; the builder and the parser derive it.
static constexpr const char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
|
|
|
|
/// Report how many times each region of the op is invoked: the single body
/// region is invoked once, independent of any constant operand values.
void ExecuteOp::getNumRegionInvocations(
    ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
  // Constant operand values do not influence the invocation count.
  (void)operands;

  // Callers are expected to pass an empty vector to be filled in.
  assert(countPerRegion.empty());
  countPerRegion.emplace_back(1);
}
|
|
|
|
void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
|
|
ArrayRef<Attribute> operands,
|
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
|
// The `body` region branch back to the parent operation.
|
|
if (index.hasValue()) {
|
|
assert(*index == 0);
|
|
regions.push_back(RegionSuccessor(getResults()));
|
|
return;
|
|
}
|
|
|
|
// Otherwise the successor is the body region.
|
|
regions.push_back(RegionSuccessor(&body()));
|
|
}
|
|
|
|
/// Build an `async.execute` operation.
///
/// Operands are laid out as [dependencies..., operands...], and the
/// `operand_segment_sizes` attribute records both segment sizes. The first
/// result is always a token, followed by each of `resultTypes` wrapped into
/// `async.value`. The body block receives one argument per operand: the
/// unwrapped payload type for `async.value` operands, otherwise the operand
/// type unchanged.
///
/// If `bodyBuilder` is provided, it populates the body. Otherwise a default
/// `async.yield` terminator is created only when `resultTypes` is empty; when
/// results are expected, the terminator is left to the caller.
void ExecuteOp::build(OpBuilder &builder, OperationState &result,
                      TypeRange resultTypes, ValueRange dependencies,
                      ValueRange operands, BodyBuilderFn bodyBuilder) {
  MLIRContext *ctx = result.getContext();

  result.addOperands(dependencies);
  result.addOperands(operands);

  // Derive the `operand_segment_sizes` attribute from the provided operands.
  int32_t dependencyCount = dependencies.size();
  int32_t valueOperandCount = operands.size();
  auto segmentSizes = DenseIntElementsAttr::get(
      VectorType::get({2}, builder.getI32Type()),
      {dependencyCount, valueOperandCount});
  result.addAttribute(kOperandSegmentSizesAttr, segmentSizes);

  // Completion token first, then every requested type as `async.value<T>`.
  result.addTypes({TokenType::get(ctx)});
  for (Type payloadType : resultTypes)
    result.addTypes(ValueType::get(payloadType));

  // Body region with a single block whose arguments unwrap the operands.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  for (Value operand : operands) {
    Type argType = operand.getType();
    if (auto asyncValueType = argType.dyn_cast<ValueType>())
      argType = asyncValueType.getValueType();
    bodyBlock.addArgument(argType);
  }

  if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&bodyBlock);
    bodyBuilder(builder, result.location, bodyBlock.getArguments());
  } else if (resultTypes.empty()) {
    // No values to return, so an empty yield is a valid default terminator.
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&bodyBlock);
    builder.create<async::YieldOp>(result.location, ValueRange());
  }
}
|
|
|
|
/// Print `async.execute` in its custom form:
///   async.execute [%deps...] (%v as %u: !async.value<T>, ...) -> (...)
///       attributes {...} { body }
/// Each bracketed/parenthesized section is printed only when non-empty.
static void print(OpAsmPrinter &p, ExecuteOp op) {
  p << op.getOperationName();

  // [%tokens,...]
  auto deps = op.dependencies();
  if (!deps.empty())
    p << " [" << deps << "]";

  // (%value as %unwrapped: !async.value<!arg.type>, ...)
  auto valueOperands = op.operands();
  if (!valueOperands.empty()) {
    Block &entry = op.body().front();
    unsigned argIdx = 0;
    p << " (";
    llvm::interleaveComma(valueOperands, p, [&](Value operand) {
      p << operand << " as " << entry.getArgument(argIdx++) << ": "
        << operand.getType();
    });
    p << ")";
  }

  // -> (!async.value<!return.type>, ...); the leading token result is implied.
  p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
  p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
}
|
|
|
|
/// Parse the custom form of `async.execute`:
///
///   async.execute [%dep, ...] (%value as %unwrapped: !async.value<T>, ...)
///       -> (!async.value<R>, ...) attributes {...} { <body> }
///
/// The dependency list, value-operand list, result type list and attribute
/// dictionary are all optional. A leading `!async.token` result is always
/// added ahead of the parsed result types.
static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
  MLIRContext *ctx = result.getContext();

  // Sizes of parsed variadic operands, will be updated below after parsing.
  int32_t numDependencies = 0;
  int32_t numOperands = 0;

  auto tokenTy = TokenType::get(ctx);

  // Parse dependency tokens.
  if (succeeded(parser.parseOptionalLSquare())) {
    SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
    if (parser.parseOperandList(tokenArgs) ||
        parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
        parser.parseRSquare())
      return failure();

    numDependencies = tokenArgs.size();
  }

  // Parse async value operands (%value as %unwrapped : !async.value<!type>).
  SmallVector<OpAsmParser::OperandType, 4> valueArgs;
  SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
  SmallVector<Type, 4> valueTypes;
  SmallVector<Type, 4> unwrappedTypes;

  if (succeeded(parser.parseOptionalLParen())) {
    auto argsLoc = parser.getCurrentLocation();

    // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
    auto parseAsyncValueArg = [&]() -> ParseResult {
      if (parser.parseOperand(valueArgs.emplace_back()) ||
          parser.parseKeyword("as") ||
          parser.parseOperand(unwrappedArgs.emplace_back()))
        return failure();

      auto typeLoc = parser.getCurrentLocation();
      if (parser.parseColonType(valueTypes.emplace_back()))
        return failure();

      // The body block argument takes the unwrapped payload type, so the
      // operand must be an `!async.value<T>`. Reject anything else here;
      // otherwise the region argument would be created with a null type.
      auto valueTy = valueTypes.back().dyn_cast<ValueType>();
      if (!valueTy)
        return parser.emitError(typeLoc, "expected async value type, got ")
               << valueTypes.back();
      unwrappedTypes.emplace_back(valueTy.getValueType());

      return success();
    };

    // If the next token is `)` skip async value arguments parsing.
    if (failed(parser.parseOptionalRParen())) {
      do {
        if (parseAsyncValueArg())
          return failure();
      } while (succeeded(parser.parseOptionalComma()));

      if (parser.parseRParen() ||
          parser.resolveOperands(valueArgs, valueTypes, argsLoc,
                                 result.operands))
        return failure();
    }

    numOperands = valueArgs.size();
  }

  // Add derived `operand_segment_sizes` attribute based on parsed operands.
  auto operandSegmentSizes = DenseIntElementsAttr::get(
      VectorType::get({2}, parser.getBuilder().getI32Type()),
      {numDependencies, numOperands});
  result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);

  // Parse the types of results returned from the async execute op.
  SmallVector<Type, 4> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();

  // Async execute first result is always a completion token.
  parser.addTypeToList(tokenTy, result.types);
  parser.addTypesToList(resultTypes, result.types);

  // Parse operation attributes.
  NamedAttrList attrs;
  if (parser.parseOptionalAttrDictWithKeyword(attrs))
    return failure();
  result.addAttributes(attrs);

  // Parse asynchronous region; its entry block arguments are the unwrapped
  // operands declared with `as` above.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
                         /*argTypes=*/{unwrappedTypes},
                         /*enableNameShadowing=*/false))
    return failure();

  return success();
}
|
|
|
|
/// Verify that the body region's argument types match the payload types of
/// the op's `async.value` operands.
static LogicalResult verify(ExecuteOp op) {
  auto payloadTypes = llvm::map_range(op.operands(), [](Value asyncOperand) {
    return asyncOperand.getType().cast<ValueType>().getValueType();
  });

  bool argsMismatch = op.body().getArgumentTypes() != payloadTypes;
  if (argsMismatch)
    return op.emitOpError("async body region argument types do not match the "
                          "execute operation arguments types");

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// AwaitOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build an `async.await` on `operand` with the given extra attributes.
/// A result is added only when `operand` is an `async.value<T>`, in which case
/// the result type is the unwrapped `T`; otherwise the op has no results.
void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
                    ArrayRef<NamedAttribute> attrs) {
  result.addOperands({operand});
  result.attributes.append(attrs.begin(), attrs.end());

  auto asyncValueTy = operand.getType().dyn_cast<ValueType>();
  if (!asyncValueTy)
    return;
  result.addTypes(asyncValueTy.getValueType());
}
|
|
|
|
/// Parse the awaited operand's type. When it is `async.value<T>`, `resultType`
/// is set to the unwrapped `T`; otherwise `resultType` is left untouched.
static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
                                        Type &resultType) {
  if (failed(parser.parseType(operandType)))
    return failure();

  auto asyncValueTy = operandType.dyn_cast<ValueType>();
  if (asyncValueTy)
    resultType = asyncValueTy.getValueType();

  return success();
}
|
|
|
|
/// Print only the awaited operand's type. The result type is not printed,
/// because the matching parser recovers it by unwrapping the operand type.
static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
                                 Type operandType, Type resultType) {
  (void)op;
  (void)resultType;
  p << operandType;
}
|
|
|
|
/// Verify `async.await`:
///  - awaiting a token must produce no results;
///  - awaiting an `async.value<T>` must produce exactly one result of type T.
static LogicalResult verify(AwaitOp op) {
  Type argType = op.operand().getType();

  // Awaiting on a token does not have any results.
  if (argType.isa<TokenType>() && !op.getResultTypes().empty())
    return op.emitOpError("awaiting on a token must have empty result");

  // Awaiting on a value unwraps the async value type.
  if (auto value = argType.dyn_cast<ValueType>()) {
    auto resultType = op.getResultType();
    // Check presence before dereferencing: an op written in generic form may
    // omit the result, and dereferencing an empty optional is invalid.
    if (!resultType)
      return op.emitOpError()
             << "awaiting on an async value must have a result of type "
             << value.getValueType();

    if (*resultType != value.getValueType())
      return op.emitOpError()
             << "result type " << *resultType
             << " does not match async value type " << value.getValueType();
  }

  return success();
}
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
|