[Flang][OpenMP] Prevent re-composition of composite constructs (#102613)

After decomposition of OpenMP compound constructs and assignment of
applicable clauses to each leaf construct, composite constructs are then
combined again into a single element in the construct queue. This helped
later lowering stages easily identify composite constructs.

However, as a result of the re-composition stage, the same list of
clauses is used to produce all MLIR operations corresponding to each
leaf of the original composite construct. This undoes existing logic
introducing implicit clauses and deciding to which leaf construct(s)
each clause applies.

This patch removes construct re-composition logic and updates Flang
lowering to be able to identify composite constructs from a list of leaf
constructs. As a result, the right set of clauses is produced for each
operation representing a leaf of a composite construct.

PR stack:
- #102612
- #102613
This commit is contained in:
Sergio Afonso 2024-08-20 11:09:54 +01:00 committed by GitHub
parent ba84cfbe0c
commit aa875cfe11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 105 additions and 496 deletions

View File

@ -22,7 +22,6 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/ClauseT.h"
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/raw_ostream.h"
@ -68,12 +67,6 @@ struct ConstructDecomposition {
};
} // namespace
/// Combine the given leaf constructs (with their clauses) into one compound
/// construct, using the construct composition utility for the given OpenMP
/// \c version, and return the resulting directive-with-clauses.
static UnitConstruct mergeConstructs(uint32_t version,
                                     llvm::ArrayRef<UnitConstruct> units) {
  tomp::ConstructCompositionT composer(version, units);
  UnitConstruct combined = composer.merged;
  return combined;
}
namespace Fortran::lower::omp {
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const UnitConstruct &uc) {
@ -90,38 +83,37 @@ ConstructQueue buildConstructQueue(
Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
llvm::omp::Directive compound, const List<Clause> &clauses) {
List<UnitConstruct> constructs;
ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
assert(!decompose.output.empty() && "Construct decomposition failed");
llvm::SmallVector<llvm::omp::Directive> loweringUnits;
std::ignore =
llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
uint32_t version = getOpenMPVersionAttribute(modOp);
int leafIndex = 0;
for (llvm::omp::Directive dir_id : loweringUnits) {
llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
llvm::omp::getLeafConstructsOrSelf(dir_id);
size_t numLeafs = leafsOrSelf.size();
llvm::ArrayRef<UnitConstruct> toMerge{&decompose.output[leafIndex],
numLeafs};
auto &uc = constructs.emplace_back(mergeConstructs(version, toMerge));
if (!transferLocations(clauses, uc.clauses)) {
// If some clauses are left without source information, use the
// directive's source.
for (auto &clause : uc.clauses) {
if (clause.source.empty())
clause.source = source;
}
}
leafIndex += numLeafs;
for (UnitConstruct &uc : decompose.output) {
assert(getLeafConstructs(uc.id).empty() && "unexpected compound directive");
// If some clauses are left without source information, use the directive's
// source.
for (auto &clause : uc.clauses)
if (clause.source.empty())
clause.source = source;
}
return constructs;
return decompose.output;
}
/// Return true when the constructs from \c item up to the end of \c queue are
/// exactly the leaf constructs of \c directive (or \c directive itself, if it
/// has no leafs), in the same order and with nothing left over on either side.
bool matchLeafSequence(ConstructQueue::const_iterator item,
                       const ConstructQueue &queue,
                       llvm::omp::Directive directive) {
  llvm::ArrayRef<llvm::omp::Directive> expectedLeafs =
      llvm::omp::getLeafConstructsOrSelf(directive);

  ConstructQueue::const_iterator current = item;
  const ConstructQueue::const_iterator queueEnd = queue.end();

  for (llvm::omp::Directive expected : expectedLeafs) {
    // The queue ran out before all expected leafs were seen, or the ids
    // differ at this position.
    if (current == queueEnd || current->id != expected)
      return false;
    ++current;
  }

  // Any remaining queue element means the queue holds more constructs than
  // the directive has leafs.
  return current == queueEnd;
}
bool isLastItemInQueue(ConstructQueue::const_iterator item,

View File

@ -10,7 +10,6 @@
#include "Clauses.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/Compiler.h"
@ -49,6 +48,15 @@ ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,
bool isLastItemInQueue(ConstructQueue::const_iterator item,
const ConstructQueue &queue);
/// Check whether the leaf constructs that make up the given \c directive
/// match, one to one and in order, the constructs in the range starting at
/// \c item and running to the end of the \c queue.
/// If \c directive doesn't represent a compound directive, check that \c item
/// matches that directive and is the only element before the end of the
/// \c queue.
/// \returns true only if both sequences have the same length and every pair of
/// corresponding directive ids is equal.
bool matchLeafSequence(ConstructQueue::const_iterator item,
const ConstructQueue &queue,
llvm::omp::Directive directive);
} // namespace Fortran::lower::omp
#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H

View File

@ -2044,6 +2044,7 @@ static void genCompositeDistributeParallelDoSimd(
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
assert(std::distance(item, queue.end()) == 4 && "Invalid leaf constructs");
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
}
@ -2054,17 +2055,23 @@ static void genCompositeDistributeSimd(
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
ConstructQueue::const_iterator distributeItem = item;
ConstructQueue::const_iterator simdItem = std::next(distributeItem);
// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
genDistributeClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
distributeClauseOps);
genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
loc, distributeClauseOps);
mlir::omp::SimdOperands simdClauseOps;
genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
loopNestClauseOps, iv);
// Operation creation.
@ -2086,7 +2093,7 @@ static void genCompositeDistributeSimd(
llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
simdOp.getRegion().getArguments()));
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
loopNestClauseOps, iv, /*wrapperSyms=*/{}, wrapperArgs,
llvm::omp::Directive::OMPD_distribute_simd, dsp);
}
@ -2100,19 +2107,25 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
ConstructQueue::const_iterator doItem = item;
ConstructQueue::const_iterator simdItem = std::next(doItem);
// Clause processing.
mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
llvm::SmallVector<mlir::Type> wsloopReductionTypes;
genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
mlir::omp::SimdOperands simdClauseOps;
genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
loopNestClauseOps, iv);
// Operation creation.
@ -2133,7 +2146,7 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>(
wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments()));
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
loopNestClauseOps, iv, wsloopReductionSyms, wrapperArgs,
llvm::omp::Directive::OMPD_do_simd, dsp);
}
@ -2143,6 +2156,7 @@ static void genCompositeTaskloopSimd(
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
TODO(loc, "Composite TASKLOOP SIMD");
}
@ -2150,6 +2164,36 @@ static void genCompositeTaskloopSimd(
// Dispatch
//===----------------------------------------------------------------------===//
/// Try to lower the constructs from \c item to the end of \c queue as one of
/// the known composite constructs. The candidates are tried in a fixed order
/// and the first one whose leaf sequence matches is lowered.
/// \returns true if a composite construct was matched and its generator was
/// invoked, false if none of the composite leaf sequences matched.
static bool genOMPCompositeDispatch(
    lower::AbstractConverter &converter, lower::SymMap &symTable,
    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
    mlir::Location loc, const ConstructQueue &queue,
    ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
  using llvm::omp::Directive;
  using lower::omp::matchLeafSequence;

  // True when the remaining queue is exactly the leaf sequence of `composite`.
  auto remainingQueueIs = [&](Directive composite) {
    return matchLeafSequence(item, queue, composite);
  };

  if (remainingQueueIs(Directive::OMPD_distribute_parallel_do)) {
    genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
                                     queue, item, dsp);
    return true;
  }
  if (remainingQueueIs(Directive::OMPD_distribute_parallel_do_simd)) {
    genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
                                         loc, queue, item, dsp);
    return true;
  }
  if (remainingQueueIs(Directive::OMPD_distribute_simd)) {
    genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
                               item, dsp);
    return true;
  }
  if (remainingQueueIs(Directive::OMPD_do_simd)) {
    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
                       dsp);
    return true;
  }
  if (remainingQueueIs(Directive::OMPD_taskloop_simd)) {
    genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
                             item, dsp);
    return true;
  }

  // Not a composite construct.
  return false;
}
static void genOMPDispatch(lower::AbstractConverter &converter,
lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
@ -2163,10 +2207,18 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
llvm::omp::Association::Loop;
if (loopLeaf) {
symTable.pushScope();
// TODO: Use one DataSharingProcessor for each leaf of a composite
// construct.
loopDsp.emplace(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
/*useDelayedPrivatization=*/false, &symTable);
loopDsp->processStep1();
if (genOMPCompositeDispatch(converter, symTable, semaCtx, eval, loc, queue,
item, *loopDsp)) {
symTable.popScope();
return;
}
}
switch (llvm::omp::Directive dir = item->id) {
@ -2262,29 +2314,11 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
// that use this construct, add a single construct for now.
genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
break;
// Composite constructs
case llvm::omp::Directive::OMPD_distribute_parallel_do:
genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
queue, item, *loopDsp);
break;
case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
loc, queue, item, *loopDsp);
break;
case llvm::omp::Directive::OMPD_distribute_simd:
genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
item, *loopDsp);
break;
case llvm::omp::Directive::OMPD_do_simd:
genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
*loopDsp);
break;
case llvm::omp::Directive::OMPD_taskloop_simd:
genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
item, *loopDsp);
break;
default:
// Combined and composite constructs should have been split into a sequence
// of leaf constructs when building the construct queue.
assert(!llvm::omp::isLeafConstruct(dir) &&
"Unexpected compound construct.");
break;
}

View File

@ -4,7 +4,7 @@
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
subroutine testDoSimdLinear(int_array)
integer :: int_array(*)
!CHECK: not yet implemented: Unhandled clause LINEAR in DO construct
!CHECK: not yet implemented: Unhandled clause LINEAR in SIMD construct
!$omp do simd linear(int_array)
do index_ = 1, 10
end do

View File

@ -197,9 +197,9 @@ subroutine nested_default_clause_tests
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_testsEk"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {

View File

@ -134,9 +134,9 @@ end program default_clause_lowering
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_test1Ez"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_test1Ex"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_test1Ek"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {

View File

@ -1,425 +0,0 @@
//===- ConstructCompositionT.h -- Composing compound constructs -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Given a list of leaf constructs, each with a set of clauses, generate the
// compound construct whose leaf constructs are the given list, and whose clause
// list is the merged list of the individual leaf clauses.
//
// *** At the moment it assumes that the individual constructs and their clauses
// *** are a subset of those created by splitting a valid compound construct.
//===----------------------------------------------------------------------===//
#ifndef LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
#define LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/ClauseT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
namespace tomp {
/// Builds a single compound construct out of a list of leaf constructs and
/// their clauses. The constructor performs the merge; the result is available
/// in the \c merged member afterwards.
template <typename ClauseType> struct ConstructCompositionT {
using ClauseTy = ClauseType;
using TypeTy = typename ClauseTy::TypeTy;
using IdTy = typename ClauseTy::IdTy;
using ExprTy = typename ClauseTy::ExprTy;
/// Merge \c leafs into \c merged. \c version is the OpenMP version used when
/// checking whether a clause is allowed on a given leaf directive.
ConstructCompositionT(uint32_t version,
llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs);
/// The compound directive and the merged clause list produced by the
/// constructor. Left default-constructed if the list of leafs was empty.
DirectiveWithClauses<ClauseTy> merged;
private:
// Use an ordered container, since we need to maintain the order in which
// clauses are added to it. This is to avoid non-deterministic output.
using ClauseSet = ListT<ClauseTy>;
enum class Presence {
All, // Clause is present on all leaf constructs that allow it.
Some, // Clause is present on some, but not on all constructs.
None, // Clause is absent on all constructs.
};
// Wrap the clause-specific payload `specific` into a ClauseTy tagged with
// `clauseId`.
template <typename S>
ClauseTy makeClause(llvm::omp::Clause clauseId, S &&specific) {
return typename ClauseTy::BaseT{clauseId, std::move(specific)};
}
// Compute the compound directive whose leaf constructs are `parts`.
llvm::omp::Directive
makeCompound(llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts);
// Classify on how many of the leafs that allow `clauseId` the clause is
// actually present.
Presence checkPresence(llvm::omp::Clause clauseId);
// There are clauses that need special handling:
// 1. "if": the "directive-name-modifier" on the merged clause may need
// to be set appropriately.
// 2. "reduction": implies "privateness" of all objects (incompatible
// with "shared"); there are rules for merging modifiers
// 3. data-sharing clauses ("shared", "private", "firstprivate",
// "lastprivate"): merged per object, see mergeDSA.
void mergeIf();
void mergeReduction();
void mergeDSA();
// OpenMP version passed to the constructor.
uint32_t version;
// The leaf constructs being merged (a non-owning view).
llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs;
// clause id -> set of leaf constructs that contain it
std::unordered_map<llvm::omp::Clause, llvm::BitVector> clausePresence;
// clause id -> set of instances of that clause
std::unordered_map<llvm::omp::Clause, ClauseSet> clauseSets;
};
// Deduction guide: lets the clause type be deduced from the element type of
// the array of leaf constructs passed to the constructor, so callers can write
// `ConstructCompositionT compose(version, leafs)` without template arguments.
template <typename ClauseTy>
ConstructCompositionT(uint32_t, llvm::ArrayRef<DirectiveWithClauses<ClauseTy>>)
-> ConstructCompositionT<ClauseTy>;
/// Merge the given leaf constructs and their clauses into \c merged.
///
/// The intended use is in splitting compound constructs while preserving
/// composite constituent constructs:
///   Step 1: split the compound construct into leaf constructs.
///   Step 2: identify a composite sub-construct and merge its leafs again.
///
/// *** At the moment it assumes that the individual constructs and their
/// *** clauses are a subset of those created by splitting a valid compound
/// *** construct.
///
/// The merge works as follows:
/// 1. Deduplicate clauses
///    - exact duplicates: e.g. shared(x) shared(x) -> shared(x)
///    - special cases of clauses differing in modifier:
///      (a) reduction: inscan + (none|default) = inscan
///      (b) reduction: task + (none|default) = task
///      (c) combine repeated "if" clauses if possible
/// 2. Merge DSA clauses: e.g. private(x) private(y) -> private(x, y).
/// 3. Resolve potential DSA conflicts (typically due to implied clauses).
template <typename C>
ConstructCompositionT<C>::ConstructCompositionT(
    uint32_t version, llvm::ArrayRef<DirectiveWithClauses<C>> leafs)
    : version(version), leafs(leafs) {
  // Nothing to merge: leave `merged` default-constructed.
  if (leafs.empty())
    return;

  merged.id = makeCompound(leafs);

  // Record, for every clause id, on which leafs it appears and which distinct
  // instances of it exist.
  const size_t numLeafs = leafs.size();
  for (size_t leafIdx = 0; leafIdx != numLeafs; ++leafIdx) {
    for (const auto &clause : leafs[leafIdx].clauses) {
      // One bit per leaf; grow the bit vector on first use.
      llvm::BitVector &presentOn = clausePresence[clause.id];
      if (presentOn.size() < numLeafs)
        presentOn.resize(numLeafs);
      presentOn.set(leafIdx);

      // Keep only one copy of identical clause instances.
      ClauseSet &instances = clauseSets[clause.id];
      if (llvm::is_contained(instances, clause))
        continue;
      instances.push_back(clause);
    }
  }

  // Clauses with dedicated merging rules.
  mergeIf();
  mergeReduction();
  mergeDSA();

  // Clause kinds that were already emitted by the merge functions above.
  const auto mergedSeparately = [](llvm::omp::Clause id) {
    switch (id) {
    case llvm::omp::Clause::OMPC_if:
    case llvm::omp::Clause::OMPC_reduction:
    case llvm::omp::Clause::OMPC_shared:
    case llvm::omp::Clause::OMPC_private:
    case llvm::omp::Clause::OMPC_firstprivate:
    case llvm::omp::Clause::OMPC_lastprivate:
      return true;
    default:
      return false;
    }
  };

  // Every other clause is copied over unchanged.
  for (auto &[id, clauses] : clauseSets) {
    if (!mergedSeparately(id))
      llvm::append_range(merged.clauses, clauses);
  }
}
/// Return the compound directive formed by the directive ids of \c parts,
/// taken in the order given.
template <typename C>
llvm::omp::Directive ConstructCompositionT<C>::makeCompound(
    llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts) {
  // Collect just the directive ids of the constituent constructs.
  llvm::SmallVector<llvm::omp::Directive> leafIds;
  for (const auto &part : parts)
    leafIds.push_back(part.id);

  return llvm::omp::getCompoundConstruct(leafIds);
}
/// Classify how widely the clause \c clauseId is used among the leafs.
/// Only leafs on which the clause is allowed (for the stored OpenMP version)
/// are taken into account.
/// \returns Presence::None if no such leaf carries the clause (including the
/// case where no leaf allows it), Presence::All if every such leaf carries it,
/// and Presence::Some otherwise.
template <typename C>
auto ConstructCompositionT<C>::checkPresence(llvm::omp::Clause clauseId)
    -> Presence {
  auto entry = clausePresence.find(clauseId);
  if (entry == clausePresence.end())
    return Presence::None;
  const llvm::BitVector &presentOn = entry->second;

  // Count the leafs that allow the clause, and those of them that have it.
  size_t numAllowed = 0;
  size_t numPresent = 0;
  for (size_t leafIdx = 0, numLeafs = leafs.size(); leafIdx != numLeafs;
       ++leafIdx) {
    if (!llvm::omp::isAllowedClauseForDirective(leafs[leafIdx].id, clauseId,
                                                version))
      continue;
    ++numAllowed;
    if (presentOn.test(leafIdx))
      ++numPresent;
  }

  if (numPresent == 0)
    return Presence::None;
  return numPresent == numAllowed ? Presence::All : Presence::Some;
}
/// Emit the merged "if" clause, if any leaf carries one.
///
/// When the clause is on all leafs that allow it, it applies to the compound
/// construct and is emitted without a "directive-name-modifier". Otherwise it
/// is emitted with the first leaf that carries it as the modifier.
/// This assumes that all "if" clauses have the same expression: the
/// expression of the first collected instance is used.
template <typename C> void ConstructCompositionT<C>::mergeIf() {
  using IfTy = tomp::clause::IfT<TypeTy, IdTy, ExprTy>;
  constexpr llvm::omp::Clause ifId = llvm::omp::Clause::OMPC_if;

  const Presence where = checkPresence(ifId);
  if (where == Presence::None)
    return;

  // Take the condition from the first recorded "if" clause.
  const ClauseTy &sample = *clauseSets[ifId].begin();
  const auto &sampleIf = std::get<IfTy>(sample.u);
  const auto &condition = std::get<typename IfTy::IfExpression>(sampleIf.t);

  if (where == Presence::All) {
    // Create "if" without "directive-name-modifier".
    merged.clauses.emplace_back(
        makeClause(ifId, IfTy{{/*DirectiveNameModifier=*/std::nullopt,
                               /*IfExpression=*/condition}}));
    return;
  }

  // Present only on some leafs: name the first leaf that has it in the
  // "directive-name-modifier".
  int firstLeaf = clausePresence[ifId].find_first();
  assert(firstLeaf >= 0);
  merged.clauses.emplace_back(
      makeClause(ifId, IfTy{{/*DirectiveNameModifier=*/leafs[firstLeaf].id,
                             /*IfExpression=*/condition}}));
}
/// Emit the merged "reduction" clauses, if any leaf carries one.
///
/// Reduction clauses that agree on their reduction identifiers and object
/// list, and differ at most in the modifier, are collapsed into one clause.
/// Reductions on different objects are kept as separate clauses.
template <typename C> void ConstructCompositionT<C>::mergeReduction() {
  Presence presence = checkPresence(llvm::omp::Clause::OMPC_reduction);
  if (presence == Presence::None)
    return;

  using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
  using ModifierTy = typename ReductionTy::ReductionModifier;
  using IdentifiersTy = typename ReductionTy::ReductionIdentifiers;
  using ListTy = typename ReductionTy::List;

  // There are exceptions on which constructs "reduction" may appear
  // (specifically "parallel", and "teams"). Assume that if "reduction"
  // is present, it can be applied to the compound construct.
  // What's left is to see if there are any modifiers present. Again,
  // assume that there are no conflicting modifiers.
  // There can be, however, multiple reductions on different objects.

  // Two reductions are "equal" if they match in everything but the modifier.
  auto equal = [](const ClauseTy &red1, const ClauseTy &red2) {
    // Bind by reference: taking the reductions by value would copy their
    // identifier and object lists on every comparison.
    const auto &r1 = std::get<ReductionTy>(red1.u);
    const auto &r2 = std::get<ReductionTy>(red2.u);
    // Compare everything except modifiers.
    if (std::get<IdentifiersTy>(r1.t) != std::get<IdentifiersTy>(r2.t))
      return false;
    if (std::get<ListTy>(r1.t) != std::get<ListTy>(r2.t))
      return false;
    return true;
  };

  auto getModifier = [](const ClauseTy &clause) {
    const ReductionTy &red = std::get<ReductionTy>(clause.u);
    return std::get<std::optional<ModifierTy>>(red.t);
  };

  const ClauseSet &reductions = clauseSets[llvm::omp::Clause::OMPC_reduction];
  // Addresses of the reductions that have already been folded into an
  // emitted clause.
  std::unordered_set<const ClauseTy *> visited;
  while (reductions.size() != visited.size()) {
    typename ClauseSet::const_iterator first;

    // Find first non-visited reduction. The loop condition guarantees that
    // at least one exists.
    for (first = reductions.begin(); first != reductions.end(); ++first) {
      if (visited.count(&*first))
        continue;
      visited.insert(&*first);
      break;
    }

    std::optional<ModifierTy> modifier = getModifier(*first);

    // Visit all other reductions that are "equal" (with respect to the
    // definition above) to "first". Collect modifiers: an absent or default
    // modifier is replaced by the modifier of a later equal reduction.
    for (auto iter = std::next(first); iter != reductions.end(); ++iter) {
      if (!equal(*first, *iter))
        continue;
      visited.insert(&*iter);
      if (!modifier || *modifier == ModifierTy::Default)
        modifier = getModifier(*iter);
    }

    const auto &firstRed = std::get<ReductionTy>(first->u);
    merged.clauses.emplace_back(makeClause(
        llvm::omp::Clause::OMPC_reduction,
        ReductionTy{
            {/*ReductionModifier=*/modifier,
             /*ReductionIdentifiers=*/std::get<IdentifiersTy>(firstRed.t),
             /*List=*/std::get<ListTy>(firstRed.t)}}));
  }
}
/// Emit the merged data-sharing attribute (DSA) clauses.
///
/// The "shared", "private", "firstprivate" and "lastprivate" clauses of all
/// leafs are combined per object. Objects that also appear in a privatizing
/// clause ("in_reduction", "linear", "reduction", "task_reduction") lose their
/// "shared" attribute. Each object then ends up in firstprivate and/or
/// lastprivate if it has either attribute, else in private, else in shared.
/// At most one clause of each resulting kind is appended to \c merged.
template <typename C> void ConstructCompositionT<C>::mergeDSA() {
  using ObjectTy = tomp::type::ObjectT<IdTy, ExprTy>;
  using SharedTy = tomp::clause::SharedT<TypeTy, IdTy, ExprTy>;
  using PrivateTy = tomp::clause::PrivateT<TypeTy, IdTy, ExprTy>;
  using FirstprivateTy = tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>;
  using LastprivateTy = tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy>;

  // Bit flags describing the data-sharing attributes seen for an object.
  enum DSA : int {
    None = 0,
    Shared = 1 << 0,
    Private = 1 << 1,
    FirstPrivate = 1 << 2,
    LastPrivate = 1 << 3,
    LastPrivateConditional = 1 << 4,
  };

  // Use ordered containers to avoid non-deterministic output.
  llvm::SmallVector<std::pair<ObjectTy, int>, 8> objectDsa;

  // Return the attribute bits recorded for `object` (matched by id), adding
  // a new entry with no attributes if the object has not been seen yet.
  auto dsaOf = [&](const ObjectTy &object) -> int & {
    for (std::pair<ObjectTy, int> &entry : objectDsa)
      if (entry.first.id() == object.id())
        return entry.second;
    return objectDsa.emplace_back(object, DSA::None).second;
  };

  // Remove the "shared" attribute from every object in `objects`.
  auto dropShared = [&](const auto &objects) {
    for (const auto &object : objects)
      dsaOf(object) &= ~DSA::Shared;
  };

  // Visit clauses that affect DSA.
  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_shared])
    for (auto &object : std::get<SharedTy>(clause.u).v)
      dsaOf(object) |= DSA::Shared;

  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_private])
    for (auto &object : std::get<PrivateTy>(clause.u).v)
      dsaOf(object) |= DSA::Private;

  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_firstprivate])
    for (auto &object : std::get<FirstprivateTy>(clause.u).v)
      dsaOf(object) |= DSA::FirstPrivate;

  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_lastprivate]) {
    using ModifierTy = typename LastprivateTy::LastprivateModifier;
    using ListTy = typename LastprivateTy::List;
    const auto &lastp = std::get<LastprivateTy>(clause.u);
    // The modifier belongs to the clause, so the same bit applies to all of
    // the clause's objects.
    const auto &mod = std::get<std::optional<ModifierTy>>(lastp.t);
    const int lastprivateBit = (mod && *mod == ModifierTy::Conditional)
                                   ? DSA::LastPrivateConditional
                                   : DSA::LastPrivate;
    for (auto &object : std::get<ListTy>(lastp.t))
      dsaOf(object) |= lastprivateBit;
  }

  // Check other privatizing clauses as well, clear "shared" if set.
  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_in_reduction]) {
    using InReductionTy = tomp::clause::InReductionT<TypeTy, IdTy, ExprTy>;
    using ListTy = typename InReductionTy::List;
    dropShared(std::get<ListTy>(std::get<InReductionTy>(clause.u).t));
  }
  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_linear]) {
    using LinearTy = tomp::clause::LinearT<TypeTy, IdTy, ExprTy>;
    using ListTy = typename LinearTy::List;
    dropShared(std::get<ListTy>(std::get<LinearTy>(clause.u).t));
  }
  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_reduction]) {
    using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
    using ListTy = typename ReductionTy::List;
    dropShared(std::get<ListTy>(std::get<ReductionTy>(clause.u).t));
  }
  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_task_reduction]) {
    using TaskReductionTy = tomp::clause::TaskReductionT<TypeTy, IdTy, ExprTy>;
    using ListTy = typename TaskReductionTy::List;
    dropShared(std::get<ListTy>(std::get<TaskReductionTy>(clause.u).t));
  }

  // Sort the objects into the clause each one will be listed in.
  tomp::ListT<ObjectTy> privateObj, sharedObj, firstpObj, lastpObj, lastpcObj;
  for (auto &[object, dsa] : objectDsa) {
    const bool isFirstprivate = (dsa & DSA::FirstPrivate) != 0;
    const bool isLastprivate = (dsa & DSA::LastPrivate) != 0;
    const bool isLastprivateCond = (dsa & DSA::LastPrivateConditional) != 0;

    if (isFirstprivate || isLastprivate || isLastprivateCond) {
      // An object can be both firstprivate and lastprivate; the conditional
      // form of lastprivate takes precedence over the plain one.
      if (isFirstprivate)
        firstpObj.push_back(object);
      if (isLastprivateCond)
        lastpcObj.push_back(object);
      else if (isLastprivate)
        lastpObj.push_back(object);
      continue;
    }

    if (dsa & DSA::Private)
      privateObj.push_back(object);
    else if (dsa & DSA::Shared)
      sharedObj.push_back(object);
  }

  // Materialize each clause.
  if (!privateObj.empty())
    merged.clauses.emplace_back(
        makeClause(llvm::omp::Clause::OMPC_private,
                   PrivateTy{/*List=*/std::move(privateObj)}));

  if (!sharedObj.empty())
    merged.clauses.emplace_back(
        makeClause(llvm::omp::Clause::OMPC_shared,
                   SharedTy{/*List=*/std::move(sharedObj)}));

  if (!firstpObj.empty())
    merged.clauses.emplace_back(
        makeClause(llvm::omp::Clause::OMPC_firstprivate,
                   FirstprivateTy{/*List=*/std::move(firstpObj)}));

  if (!lastpObj.empty())
    merged.clauses.emplace_back(
        makeClause(llvm::omp::Clause::OMPC_lastprivate,
                   LastprivateTy{{/*LastprivateModifier=*/std::nullopt,
                                  /*List=*/std::move(lastpObj)}}));

  if (!lastpcObj.empty()) {
    auto conditional = LastprivateTy::LastprivateModifier::Conditional;
    merged.clauses.emplace_back(
        makeClause(llvm::omp::Clause::OMPC_lastprivate,
                   LastprivateTy{{/*LastprivateModifier=*/conditional,
                                  /*List=*/std::move(lastpcObj)}}));
  }
}
} // namespace tomp
#endif // LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H