mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-19 02:56:49 +00:00
[mlir][flang][openmp] Rework parallel reduction operations (#79308)
This patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction. The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region. This patch only makes the change for the `parallel` operation, the change for the `wsloop` operation will be in a separate patch. --------- Co-authored-by: Kiran Chandramohan <kiran.chandramohan@arm.com>
This commit is contained in:
parent
1114ac4399
commit
9ecf4d20bb
@ -621,10 +621,12 @@ public:
|
||||
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
|
||||
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
|
||||
*mapSymbols = nullptr) const;
|
||||
bool processReduction(
|
||||
mlir::Location currentLocation,
|
||||
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
|
||||
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
|
||||
bool
|
||||
processReduction(mlir::Location currentLocation,
|
||||
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
|
||||
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
|
||||
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
|
||||
*reductionSymbols = nullptr) const;
|
||||
bool processSectionsReduction(mlir::Location currentLocation) const;
|
||||
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
|
||||
bool
|
||||
@ -1079,12 +1081,14 @@ public:
|
||||
|
||||
/// Creates a reduction declaration and associates it with an OpenMP block
|
||||
/// directive.
|
||||
static void addReductionDecl(
|
||||
mlir::Location currentLocation,
|
||||
Fortran::lower::AbstractConverter &converter,
|
||||
const Fortran::parser::OmpReductionClause &reduction,
|
||||
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
|
||||
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
|
||||
static void
|
||||
addReductionDecl(mlir::Location currentLocation,
|
||||
Fortran::lower::AbstractConverter &converter,
|
||||
const Fortran::parser::OmpReductionClause &reduction,
|
||||
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
|
||||
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
|
||||
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
|
||||
*reductionSymbols = nullptr) {
|
||||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
||||
mlir::omp::ReductionDeclareOp decl;
|
||||
const auto &redOperator{
|
||||
@ -1114,6 +1118,8 @@ public:
|
||||
if (const auto *name{
|
||||
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
|
||||
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
|
||||
if (reductionSymbols)
|
||||
reductionSymbols->push_back(symbol);
|
||||
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
||||
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
|
||||
symVal = declOp.getBase();
|
||||
@ -1148,6 +1154,8 @@ public:
|
||||
if (const auto *name{
|
||||
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
|
||||
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
|
||||
if (reductionSymbols)
|
||||
reductionSymbols->push_back(symbol);
|
||||
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
||||
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
|
||||
symVal = declOp.getBase();
|
||||
@ -1948,13 +1956,16 @@ bool ClauseProcessor::processMap(
|
||||
bool ClauseProcessor::processReduction(
|
||||
mlir::Location currentLocation,
|
||||
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
|
||||
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
|
||||
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
|
||||
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
|
||||
const {
|
||||
return findRepeatableClause<ClauseTy::Reduction>(
|
||||
[&](const ClauseTy::Reduction *reductionClause,
|
||||
const Fortran::parser::CharBlock &) {
|
||||
ReductionProcessor rp;
|
||||
rp.addReductionDecl(currentLocation, converter, reductionClause->v,
|
||||
reductionVars, reductionDeclSymbols);
|
||||
reductionVars, reductionDeclSymbols,
|
||||
reductionSymbols);
|
||||
});
|
||||
}
|
||||
|
||||
@ -2304,6 +2315,14 @@ struct OpWithBodyGenInfo {
|
||||
return *this;
|
||||
}
|
||||
|
||||
OpWithBodyGenInfo &
|
||||
setReductions(llvm::SmallVector<const Fortran::semantics::Symbol *> *value1,
|
||||
llvm::SmallVector<mlir::Type> *value2) {
|
||||
reductionSymbols = value1;
|
||||
reductionTypes = value2;
|
||||
return *this;
|
||||
}
|
||||
|
||||
OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) {
|
||||
genRegionEntryCB = value;
|
||||
return *this;
|
||||
@ -2323,6 +2342,11 @@ struct OpWithBodyGenInfo {
|
||||
const Fortran::parser::OmpClauseList *clauses = nullptr;
|
||||
/// [in] if provided, processes the construct's data-sharing attributes.
|
||||
DataSharingProcessor *dsp = nullptr;
|
||||
/// [in] if provided, list of reduction symbols
|
||||
llvm::SmallVector<const Fortran::semantics::Symbol *> *reductionSymbols =
|
||||
nullptr;
|
||||
/// [in] if provided, list of reduction types
|
||||
llvm::SmallVector<mlir::Type> *reductionTypes = nullptr;
|
||||
/// [in] if provided, emits the op's region entry. Otherwise, an emtpy block
|
||||
/// is created in the region.
|
||||
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
|
||||
@ -2567,6 +2591,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
|
||||
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
|
||||
reductionVars;
|
||||
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
|
||||
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
|
||||
|
||||
ClauseProcessor cp(converter, clauseList);
|
||||
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
|
||||
@ -2576,13 +2601,33 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
|
||||
cp.processDefault();
|
||||
cp.processAllocate(allocatorOperands, allocateOperands);
|
||||
if (!outerCombined)
|
||||
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
|
||||
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
|
||||
&reductionSymbols);
|
||||
|
||||
llvm::SmallVector<mlir::Type> reductionTypes;
|
||||
reductionTypes.reserve(reductionVars.size());
|
||||
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
|
||||
[](mlir::Value v) { return v.getType(); });
|
||||
|
||||
auto reductionCallback = [&](mlir::Operation *op) {
|
||||
llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
|
||||
currentLocation);
|
||||
auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
|
||||
reductionTypes, locs);
|
||||
for (auto [arg, prv] :
|
||||
llvm::zip_equal(reductionSymbols, block->getArguments())) {
|
||||
converter.bindSymbol(*arg, prv);
|
||||
}
|
||||
return reductionSymbols;
|
||||
};
|
||||
|
||||
return genOpWithBody<mlir::omp::ParallelOp>(
|
||||
OpWithBodyGenInfo(converter, currentLocation, eval)
|
||||
.setGenNested(genNested)
|
||||
.setOuterCombined(outerCombined)
|
||||
.setClauses(&clauseList),
|
||||
.setClauses(&clauseList)
|
||||
.setReductions(&reductionSymbols, &reductionTypes)
|
||||
.setGenRegionEntryCb(reductionCallback),
|
||||
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
|
||||
numThreadsClauseOperand, allocateOperands, allocatorOperands,
|
||||
reductionVars,
|
||||
@ -3634,10 +3679,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
break;
|
||||
}
|
||||
|
||||
if (singleDirective) {
|
||||
genOpenMPReduction(converter, beginClauseList);
|
||||
if (singleDirective)
|
||||
return;
|
||||
}
|
||||
|
||||
// Codegen for combined directives
|
||||
bool combinedDirective = false;
|
||||
@ -3673,7 +3716,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
")");
|
||||
|
||||
genNestedEvaluations(converter, eval);
|
||||
genOpenMPReduction(converter, beginClauseList);
|
||||
}
|
||||
|
||||
static void
|
||||
|
@ -27,9 +27,11 @@
|
||||
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
|
||||
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
|
||||
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
|
||||
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
|
||||
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
|
||||
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
|
||||
!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
|
||||
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
|
||||
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
@ -48,9 +50,11 @@ end subroutine
|
||||
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
|
||||
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
|
||||
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
|
||||
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
|
||||
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
|
||||
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
|
||||
!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
|
||||
!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
|
||||
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
@ -72,11 +76,15 @@ end subroutine
|
||||
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
|
||||
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
|
||||
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
|
||||
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
|
||||
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
|
||||
!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
|
||||
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
|
||||
!CHECK: fir.store %[[RES1]] to %[[PRV1]]
|
||||
!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
|
||||
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
|
||||
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
|
||||
!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
|
||||
!CHECK: fir.store %[[RES0]] to %[[PRV0]]
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
|
@ -28,9 +28,12 @@
|
||||
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
|
||||
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
|
||||
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
|
||||
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
|
||||
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
@ -50,9 +53,12 @@ end subroutine
|
||||
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
|
||||
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
|
||||
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
|
||||
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
|
||||
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
|
||||
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
|
||||
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
|
||||
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
|
||||
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
|
||||
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref<f32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
@ -76,11 +82,17 @@ end subroutine
|
||||
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
|
||||
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
|
||||
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
|
||||
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
|
||||
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
|
||||
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
|
||||
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref<f32>
|
||||
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
|
||||
!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref<f32>
|
||||
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
|
||||
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
|
||||
!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
|
38
flang/test/Lower/OpenMP/parallel-reduction.f90
Normal file
38
flang/test/Lower/OpenMP/parallel-reduction.f90
Normal file
@ -0,0 +1,38 @@
|
||||
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
|
||||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
|
||||
|
||||
!CHECK: omp.reduction.declare @[[REDUCTION_DECLARE:[_a-z0-9]+]] : i32 init {
|
||||
!CHECK: ^bb0(%{{.*}}: i32):
|
||||
!CHECK: %[[I0:[_a-z0-9]+]] = arith.constant 0 : i32
|
||||
!CHECK: omp.yield(%[[I0]] : i32)
|
||||
!CHECK: } combiner {
|
||||
!CHECK: ^bb0(%[[C0:[_a-z0-9]+]]: i32, %[[C1:[_a-z0-9]+]]: i32):
|
||||
!CHECK: %[[CR:[_a-z0-9]+]] = arith.addi %[[C0]], %[[C1]] : i32
|
||||
!CHECK: omp.yield(%[[CR]] : i32)
|
||||
!CHECK: }
|
||||
!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "mn"} {
|
||||
!CHECK: %[[RED_ACCUM_REF:[_a-z0-9]+]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
|
||||
!CHECK: %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32
|
||||
!CHECK: hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
|
||||
!CHECK: %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32
|
||||
!CHECK: hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: %[[RED_ACCUM_VAL:[_a-z0-9]+]] = fir.load %[[RED_ACCUM_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: {{.*}} = fir.call @_FortranAioOutputInteger32(%{{.*}}, %[[RED_ACCUM_VAL]]) fastmath<contract> : (!fir.ref<i8>, i32) -> i1
|
||||
!CHECK: return
|
||||
!CHECK: }
|
||||
|
||||
program mn
|
||||
integer :: i
|
||||
i = 0
|
||||
|
||||
!$omp parallel reduction(+:i)
|
||||
i = 1
|
||||
!$omp end parallel
|
||||
|
||||
print *, i
|
||||
end program
|
@ -191,11 +191,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
|
||||
unsigned getNumReductionVars() { return getReductionVars().size(); }
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
oilist( `reduction` `(`
|
||||
custom<ReductionVarList>(
|
||||
$reduction_vars, type($reduction_vars), $reductions
|
||||
) `)`
|
||||
| `if` `(` $if_expr_var `:` type($if_expr_var) `)`
|
||||
oilist(
|
||||
`if` `(` $if_expr_var `:` type($if_expr_var) `)`
|
||||
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
|
||||
| `allocate` `(`
|
||||
custom<AllocateAndAllocator>(
|
||||
@ -203,7 +200,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
|
||||
$allocators_vars, type($allocators_vars)
|
||||
) `)`
|
||||
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
|
||||
) $region attr-dict
|
||||
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "mlir/Interfaces/FoldInterfaces.h"
|
||||
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/STLForwardCompat.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
@ -34,6 +35,7 @@
|
||||
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::omp;
|
||||
@ -427,6 +429,71 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
|
||||
// Parser, printer and verifier for ReductionVarList
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult
|
||||
parseReductionClause(OpAsmParser &parser, Region ®ion,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
|
||||
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
|
||||
SmallVectorImpl<OpAsmParser::Argument> &privates) {
|
||||
if (failed(parser.parseOptionalKeyword("reduction")))
|
||||
return failure();
|
||||
|
||||
SmallVector<SymbolRefAttr> reductionVec;
|
||||
|
||||
if (failed(
|
||||
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
|
||||
if (parser.parseAttribute(reductionVec.emplace_back()) ||
|
||||
parser.parseOperand(operands.emplace_back()) ||
|
||||
parser.parseArrow() ||
|
||||
parser.parseArgument(privates.emplace_back()) ||
|
||||
parser.parseColonType(types.emplace_back()))
|
||||
return failure();
|
||||
return success();
|
||||
})))
|
||||
return failure();
|
||||
|
||||
for (auto [prv, type] : llvm::zip_equal(privates, types)) {
|
||||
prv.type = type;
|
||||
}
|
||||
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
|
||||
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printReductionClause(OpAsmPrinter &p, Operation *op, Region ®ion,
|
||||
ValueRange operands, TypeRange types,
|
||||
ArrayAttr reductionSymbols) {
|
||||
p << "reduction(";
|
||||
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
|
||||
region.front().getArguments(), types),
|
||||
p, [&p](auto t) {
|
||||
auto [sym, op, arg, type] = t;
|
||||
p << sym << " " << op << " -> " << arg << " : "
|
||||
<< type;
|
||||
});
|
||||
p << ") ";
|
||||
}
|
||||
|
||||
static ParseResult
|
||||
parseParallelRegion(OpAsmParser &parser, Region ®ion,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
|
||||
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
|
||||
|
||||
llvm::SmallVector<OpAsmParser::Argument> privates;
|
||||
if (succeeded(parseReductionClause(parser, region, operands, types,
|
||||
reductionSymbols, privates)))
|
||||
return parser.parseRegion(region, privates);
|
||||
|
||||
return parser.parseRegion(region);
|
||||
}
|
||||
|
||||
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
|
||||
ValueRange operands, TypeRange types,
|
||||
ArrayAttr reductionSymbols) {
|
||||
if (reductionSymbols)
|
||||
printReductionClause(p, op, region, operands, types, reductionSymbols);
|
||||
p.printRegion(region, /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
/// reduction-entry-list ::= reduction-entry
|
||||
/// | reduction-entry-list `,` reduction-entry
|
||||
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
|
||||
@ -1114,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region ®ion,
|
||||
loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
|
||||
for (auto &iv : ivs)
|
||||
iv.type = loopVarType;
|
||||
|
||||
return parser.parseRegion(region, ivs);
|
||||
}
|
||||
|
||||
|
@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
|
||||
// Allocate reduction vars
|
||||
SmallVector<llvm::Value *> privateReductionVariables;
|
||||
DenseMap<Value, llvm::Value *> reductionVariableMap;
|
||||
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
|
||||
reductionDecls, privateReductionVariables,
|
||||
reductionVariableMap);
|
||||
{
|
||||
llvm::IRBuilderBase::InsertPointGuard guard(builder);
|
||||
builder.restoreIP(allocaIP);
|
||||
auto args = opInst.getRegion().getArguments();
|
||||
|
||||
for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
|
||||
llvm::Value *var = builder.CreateAlloca(
|
||||
moduleTranslation.convertType(reductionDecls[i].getType()));
|
||||
moduleTranslation.mapValue(args[i], var);
|
||||
privateReductionVariables.push_back(var);
|
||||
reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
|
||||
}
|
||||
}
|
||||
|
||||
// Store the mapping between reduction variables and their private copies on
|
||||
// ModuleTranslation stack. It can be then recovered when translating
|
||||
|
@ -640,11 +640,13 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
|
||||
func.func @parallel_reduction() {
|
||||
%c1 = arith.constant 1 : i32
|
||||
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr)
|
||||
omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
|
||||
// CHECK: omp.parallel reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr)
|
||||
omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
|
||||
%1 = arith.constant 2.0 : f32
|
||||
// CHECK: omp.reduction %{{.+}}, %{{.+}}
|
||||
omp.reduction %1, %0 : f32, !llvm.ptr
|
||||
%2 = llvm.load %prv : !llvm.ptr -> f32
|
||||
// CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32
|
||||
%3 = llvm.fadd %1, %2 : f32
|
||||
llvm.store %3, %prv : f32, !llvm.ptr
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
@ -654,13 +656,14 @@ func.func @parallel_reduction() {
|
||||
func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
|
||||
%c1 = arith.constant 1 : i32
|
||||
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
|
||||
omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
|
||||
// CHECK: omp.parallel reduction(@add_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
|
||||
omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
|
||||
// CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
|
||||
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
|
||||
%1 = arith.constant 2.0 : f32
|
||||
// CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
|
||||
omp.reduction %1, %0 : f32, !llvm.ptr
|
||||
%2 = llvm.load %prv : !llvm.ptr -> f32
|
||||
// CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
|
||||
llvm.fadd %1, %2 : f32
|
||||
// CHECK: omp.yield
|
||||
omp.yield
|
||||
}
|
||||
@ -799,11 +802,14 @@ func.func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
|
||||
// CHECK-LABEL: func @parallel_reduction2
|
||||
func.func @parallel_reduction2() {
|
||||
%0 = memref.alloca() : memref<1xf32>
|
||||
// CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
|
||||
omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) {
|
||||
// CHECK: omp.parallel reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>)
|
||||
omp.parallel reduction(@add2_f32 %0 -> %prv : memref<1xf32>) {
|
||||
%1 = arith.constant 2.0 : f32
|
||||
// CHECK: omp.reduction
|
||||
omp.reduction %1, %0 : f32, memref<1xf32>
|
||||
%2 = arith.constant 0 : index
|
||||
%3 = memref.load %prv[%2] : memref<1xf32>
|
||||
// CHECK: llvm.fadd
|
||||
%4 = llvm.fadd %1, %3 : f32
|
||||
memref.store %4, %prv[%2] : memref<1xf32>
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
@ -813,13 +819,14 @@ func.func @parallel_reduction2() {
|
||||
func.func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) {
|
||||
%c1 = arith.constant 1 : i32
|
||||
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr) {
|
||||
omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr) {
|
||||
// CHECK: omp.parallel reduction(@add2_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
|
||||
omp.parallel reduction(@add2_f32 %0 -> %prv : !llvm.ptr) {
|
||||
// CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
|
||||
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
|
||||
%1 = arith.constant 2.0 : f32
|
||||
// CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
|
||||
omp.reduction %1, %0 : f32, !llvm.ptr
|
||||
%2 = llvm.load %prv : !llvm.ptr -> f32
|
||||
// CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
|
||||
%3 = llvm.fadd %1, %2 : f32
|
||||
// CHECK: omp.yield
|
||||
omp.yield
|
||||
}
|
||||
|
@ -441,9 +441,11 @@ atomic {
|
||||
llvm.func @simple_reduction_parallel() {
|
||||
%c1 = llvm.mlir.constant(1 : i32) : i32
|
||||
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
|
||||
omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
|
||||
omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
|
||||
%1 = llvm.mlir.constant(2.0 : f32) : f32
|
||||
omp.reduction %1, %0 : f32, !llvm.ptr
|
||||
%2 = llvm.load %prv : !llvm.ptr -> f32
|
||||
%3 = llvm.fadd %2, %1 : f32
|
||||
llvm.store %3, %prv : f32, !llvm.ptr
|
||||
omp.terminator
|
||||
}
|
||||
llvm.return
|
||||
@ -512,10 +514,12 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) {
|
||||
%lb = llvm.mlir.constant(1 : i64) : i64
|
||||
%step = llvm.mlir.constant(1 : i64) : i64
|
||||
|
||||
omp.parallel reduction(@add_i32 -> %0 : !llvm.ptr) {
|
||||
omp.parallel reduction(@add_i32 %0 -> %prv : !llvm.ptr) {
|
||||
omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
|
||||
%ival = llvm.trunc %iv : i64 to i32
|
||||
omp.reduction %ival, %0 : i32, !llvm.ptr
|
||||
%lprv = llvm.load %prv : !llvm.ptr -> i32
|
||||
%add = llvm.add %lprv, %ival : i32
|
||||
llvm.store %add, %prv : i32, !llvm.ptr
|
||||
omp.yield
|
||||
}
|
||||
omp.terminator
|
||||
|
Loading…
x
Reference in New Issue
Block a user