mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-16 07:46:45 +00:00
[OpenACC][CIR] Implement 'device_type' clause lowering for 'init'/'sh… (#135102)
…utdown' This patch emits the lowering for 'device_type' on an 'init' or 'shutdown'. This one is fairly unique, as these directives have it as an attribute, rather than as a component of the individual operands, like the rest of the constructs. So this patch implements the lowering as an attribute. In order to do tis, a few refactorings had to happen: First, the 'emitOpenACCOp' functions needed to pick up th edirective kind/location so that the NYI diagnostic could be reasonable. Second, and most impactful, the `applyAttributes` function ends up needing to encode some of the appertainment rules, thanks to the way the OpenACC-MLIR operands get their attributes attached. Since they each use a special function (rather than something that can be legalized at runtime), the forms of 'setDefaultAttr' is only valid for some ops. SO this patch uses some `if constexpr` and a small type-trait to help legalize these.
This commit is contained in:
parent
dcb9078081
commit
74c2b41feb
@ -604,15 +604,16 @@ public:
|
||||
private:
|
||||
template <typename Op>
|
||||
mlir::LogicalResult
|
||||
emitOpenACCOp(mlir::Location start,
|
||||
emitOpenACCOp(mlir::Location start, OpenACCDirectiveKind dirKind,
|
||||
SourceLocation dirLoc,
|
||||
llvm::ArrayRef<const OpenACCClause *> clauses);
|
||||
// Function to do the basic implementation of an operation with an Associated
|
||||
// Statement. Models AssociatedStmtConstruct.
|
||||
template <typename Op, typename TermOp>
|
||||
mlir::LogicalResult
|
||||
emitOpenACCOpAssociatedStmt(mlir::Location start, mlir::Location end,
|
||||
llvm::ArrayRef<const OpenACCClause *> clauses,
|
||||
const Stmt *associatedStmt);
|
||||
mlir::LogicalResult emitOpenACCOpAssociatedStmt(
|
||||
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
|
||||
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
|
||||
const Stmt *associatedStmt);
|
||||
|
||||
public:
|
||||
mlir::LogicalResult
|
||||
|
@ -9,6 +9,7 @@
|
||||
// Emit OpenACC Stmt nodes as CIR code.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include <type_traits>
|
||||
|
||||
#include "CIRGenBuilder.h"
|
||||
#include "CIRGenFunction.h"
|
||||
@ -23,14 +24,29 @@ using namespace cir;
|
||||
using namespace mlir::acc;
|
||||
|
||||
namespace {
|
||||
// Simple type-trait to see if the first template arg is one of the list, so we
|
||||
// can tell whether to `if-constexpr` a bunch of stuff.
|
||||
template <typename ToTest, typename T, typename... Tys>
|
||||
constexpr bool isOneOfTypes =
|
||||
std::is_same_v<ToTest, T> || isOneOfTypes<ToTest, Tys...>;
|
||||
template <typename ToTest, typename T>
|
||||
constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
|
||||
|
||||
class OpenACCClauseCIREmitter final
|
||||
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
|
||||
CIRGenModule &cgm;
|
||||
// This is necessary since a few of the clauses emit differently based on the
|
||||
// directive kind they are attached to.
|
||||
OpenACCDirectiveKind dirKind;
|
||||
SourceLocation dirLoc;
|
||||
|
||||
struct AttributeData {
|
||||
// Value of the 'default' attribute, added on 'data' and 'compute'/etc
|
||||
// constructs as a 'default-attr'.
|
||||
std::optional<ClauseDefaultValue> defaultVal = std::nullopt;
|
||||
// For directives that have their device type architectures listed in
|
||||
// attributes (init/shutdown/etc), the list of architectures to be emitted.
|
||||
llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{};
|
||||
} attrData;
|
||||
|
||||
void clauseNotImplemented(const OpenACCClause &c) {
|
||||
@ -38,7 +54,9 @@ class OpenACCClauseCIREmitter final
|
||||
}
|
||||
|
||||
public:
|
||||
OpenACCClauseCIREmitter(CIRGenModule &cgm) : cgm(cgm) {}
|
||||
OpenACCClauseCIREmitter(CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
|
||||
SourceLocation dirLoc)
|
||||
: cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}
|
||||
|
||||
void VisitClause(const OpenACCClause &clause) {
|
||||
clauseNotImplemented(clause);
|
||||
@ -57,31 +75,92 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) {
|
||||
// '*' case leaves no identifier-info, just a nullptr.
|
||||
if (!ii)
|
||||
return mlir::acc::DeviceType::Star;
|
||||
return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName())
|
||||
.CaseLower("default", mlir::acc::DeviceType::Default)
|
||||
.CaseLower("host", mlir::acc::DeviceType::Host)
|
||||
.CaseLower("multicore", mlir::acc::DeviceType::Multicore)
|
||||
.CasesLower("nvidia", "acc_device_nvidia",
|
||||
mlir::acc::DeviceType::Nvidia)
|
||||
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
|
||||
}
|
||||
|
||||
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
|
||||
|
||||
switch (dirKind) {
|
||||
case OpenACCDirectiveKind::Init:
|
||||
case OpenACCDirectiveKind::Shutdown: {
|
||||
// Device type has a list that is either a 'star' (emitted as 'star'),
|
||||
// or an identifer list, all of which get added for attributes.
|
||||
|
||||
for (const DeviceTypeArgument &arg : clause.getArchitectures())
|
||||
attrData.deviceTypeArchs.push_back(decodeDeviceType(arg.first));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return clauseNotImplemented(clause);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply any of the clauses that resulted in an 'attribute'.
|
||||
template <typename Op> void applyAttributes(Op &op) {
|
||||
if (attrData.defaultVal.has_value())
|
||||
op.setDefaultAttr(*attrData.defaultVal);
|
||||
template <typename Op>
|
||||
void applyAttributes(CIRGenBuilderTy &builder, Op &op) {
|
||||
|
||||
if (attrData.defaultVal.has_value()) {
|
||||
// FIXME: OpenACC: as we implement this for other directive kinds, we have
|
||||
// to expand this list.
|
||||
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
|
||||
// operations listed in the rest of the arguments.
|
||||
if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
|
||||
op.setDefaultAttr(*attrData.defaultVal);
|
||||
else
|
||||
cgm.errorNYI(dirLoc, "OpenACC 'default' clause lowering for ", dirKind);
|
||||
}
|
||||
|
||||
if (!attrData.deviceTypeArchs.empty()) {
|
||||
// FIXME: OpenACC: as we implement this for other directive kinds, we have
|
||||
// to expand this list, or more likely, have a 'noop' branch as most other
|
||||
// uses of this apply to the operands instead.
|
||||
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
|
||||
if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
|
||||
llvm::SmallVector<mlir::Attribute> deviceTypes;
|
||||
for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs)
|
||||
deviceTypes.push_back(
|
||||
mlir::acc::DeviceTypeAttr::get(builder.getContext(), DT));
|
||||
|
||||
op.setDeviceTypesAttr(
|
||||
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
|
||||
} else {
|
||||
cgm.errorNYI(dirLoc, "OpenACC 'device_type' clause lowering for ",
|
||||
dirKind);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Op, typename TermOp>
|
||||
mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
|
||||
mlir::Location start, mlir::Location end,
|
||||
llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
|
||||
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
|
||||
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
|
||||
const Stmt *associatedStmt) {
|
||||
mlir::LogicalResult res = mlir::success();
|
||||
|
||||
llvm::SmallVector<mlir::Type> retTy;
|
||||
llvm::SmallVector<mlir::Value> operands;
|
||||
|
||||
// Clause-emitter must be here because it might modify operands.
|
||||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
|
||||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
|
||||
clauseEmitter.VisitClauseList(clauses);
|
||||
|
||||
auto op = builder.create<Op>(start, retTy, operands);
|
||||
|
||||
// Apply the attributes derived from the clauses.
|
||||
clauseEmitter.applyAttributes(op);
|
||||
clauseEmitter.applyAttributes(builder, op);
|
||||
|
||||
mlir::Block &block = op.getRegion().emplaceBlock();
|
||||
mlir::OpBuilder::InsertionGuard guardCase(builder);
|
||||
@ -95,19 +174,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
mlir::LogicalResult
|
||||
CIRGenFunction::emitOpenACCOp(mlir::Location start,
|
||||
llvm::ArrayRef<const OpenACCClause *> clauses) {
|
||||
mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
|
||||
mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
|
||||
llvm::ArrayRef<const OpenACCClause *> clauses) {
|
||||
mlir::LogicalResult res = mlir::success();
|
||||
|
||||
llvm::SmallVector<mlir::Type> retTy;
|
||||
llvm::SmallVector<mlir::Value> operands;
|
||||
|
||||
// Clause-emitter must be here because it might modify operands.
|
||||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
|
||||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
|
||||
clauseEmitter.VisitClauseList(clauses);
|
||||
|
||||
builder.create<Op>(start, retTy, operands);
|
||||
auto op = builder.create<Op>(start, retTy, operands);
|
||||
// Apply the attributes derived from the clauses.
|
||||
clauseEmitter.applyAttributes(builder, op);
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -119,13 +200,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
|
||||
switch (s.getDirectiveKind()) {
|
||||
case OpenACCDirectiveKind::Parallel:
|
||||
return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
|
||||
start, end, s.clauses(), s.getStructuredBlock());
|
||||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
|
||||
s.getStructuredBlock());
|
||||
case OpenACCDirectiveKind::Serial:
|
||||
return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
|
||||
start, end, s.clauses(), s.getStructuredBlock());
|
||||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
|
||||
s.getStructuredBlock());
|
||||
case OpenACCDirectiveKind::Kernels:
|
||||
return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
|
||||
start, end, s.clauses(), s.getStructuredBlock());
|
||||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
|
||||
s.getStructuredBlock());
|
||||
default:
|
||||
llvm_unreachable("invalid compute construct kind");
|
||||
}
|
||||
@ -137,18 +221,22 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
|
||||
mlir::Location end = getLoc(s.getSourceRange().getEnd());
|
||||
|
||||
return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
|
||||
start, end, s.clauses(), s.getStructuredBlock());
|
||||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
|
||||
s.getStructuredBlock());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
|
||||
mlir::Location start = getLoc(s.getSourceRange().getEnd());
|
||||
return emitOpenACCOp<InitOp>(start, s.clauses());
|
||||
return emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
|
||||
s.clauses());
|
||||
}
|
||||
|
||||
mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
|
||||
const OpenACCShutdownConstruct &s) {
|
||||
mlir::Location start = getLoc(s.getSourceRange().getEnd());
|
||||
return emitOpenACCOp<ShutdownOp>(start, s.clauses());
|
||||
return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
|
||||
s.getDirectiveLoc(), s.clauses());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
|
@ -4,4 +4,17 @@ void acc_init(void) {
|
||||
// CHECK: cir.func @acc_init() {
|
||||
#pragma acc init
|
||||
// CHECK-NEXT: acc.init loc(#{{[a-zA-Z0-9]+}}){{$}}
|
||||
|
||||
#pragma acc init device_type(*)
|
||||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<star>]}
|
||||
#pragma acc init device_type(nvidia)
|
||||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
|
||||
#pragma acc init device_type(host, multicore)
|
||||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
|
||||
#pragma acc init device_type(NVIDIA)
|
||||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
|
||||
#pragma acc init device_type(HoSt, MuLtIcORe)
|
||||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
|
||||
#pragma acc init device_type(HoSt) device_type(MuLtIcORe)
|
||||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
|
||||
}
|
||||
|
@ -4,4 +4,17 @@ void acc_shutdown(void) {
|
||||
// CHECK: cir.func @acc_shutdown() {
|
||||
#pragma acc shutdown
|
||||
// CHECK-NEXT: acc.shutdown loc(#{{[a-zA-Z0-9]+}}){{$}}
|
||||
|
||||
#pragma acc shutdown device_type(*)
|
||||
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<star>]}
|
||||
#pragma acc shutdown device_type(nvidia)
|
||||
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
|
||||
#pragma acc shutdown device_type(host, multicore)
|
||||
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
|
||||
#pragma acc shutdown device_type(NVIDIA)
|
||||
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
|
||||
#pragma acc shutdown device_type(HoSt, MuLtIcORe)
|
||||
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
|
||||
#pragma acc shutdown device_type(HoSt) device_type(MuLtIcORe)
|
||||
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user