[mlir][flang][openmp] Rework parallel reduction operations (#79308)

This patch reworks the way that parallel reduction operations function
to better match the expected semantics from the OpenMP specification.
Previously specific omp.reduction operations were used inside the
region, meaning that the reduction only applied when the correct
operation was used, whereas the specification states that any change to
the variable inside the region should be taken into account for the
reduction.

The new semantics create a private reduction variable as a block
argument which should be used normally for all operations on that
variable in the region; this private variable is then combined with the
others into the shared variable. This way no special omp.reduction
operations are needed inside the region.

This patch only makes the change for the `parallel` operation; the
change for the `wsloop` operation will be in a separate patch.

---------

Co-authored-by: Kiran Chandramohan <kiran.chandramohan@arm.com>
This commit is contained in:
David Truby 2024-02-12 17:19:49 +00:00 committed by GitHub
parent 1114ac4399
commit 9ecf4d20bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 249 additions and 63 deletions

View File

@ -621,10 +621,12 @@ public:
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*mapSymbols = nullptr) const;
bool processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
bool
processReduction(mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols = nullptr) const;
bool processSectionsReduction(mlir::Location currentLocation) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
@ -1079,12 +1081,14 @@ public:
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
static void addReductionDecl(
mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpReductionClause &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
static void
addReductionDecl(mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpReductionClause &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols = nullptr) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
@ -1114,6 +1118,8 @@ public:
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@ -1148,6 +1154,8 @@ public:
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@ -1948,13 +1956,16 @@ bool ClauseProcessor::processMap(
bool ClauseProcessor::processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
const {
return findRepeatableClause<ClauseTy::Reduction>(
[&](const ClauseTy::Reduction *reductionClause,
const Fortran::parser::CharBlock &) {
ReductionProcessor rp;
rp.addReductionDecl(currentLocation, converter, reductionClause->v,
reductionVars, reductionDeclSymbols);
reductionVars, reductionDeclSymbols,
reductionSymbols);
});
}
@ -2304,6 +2315,14 @@ struct OpWithBodyGenInfo {
return *this;
}
/// Registers the reduction information for the operation being generated.
///
/// \param value1 pointer stored as the list of reduction symbols.
/// \param value2 pointer stored as the list of reduction types.
/// \returns this object, so that setter calls can be chained.
///
/// Only the pointers are stored; the vectors themselves are not copied here.
OpWithBodyGenInfo &
setReductions(llvm::SmallVector<const Fortran::semantics::Symbol *> *value1,
              llvm::SmallVector<mlir::Type> *value2) {
  this->reductionTypes = value2;
  this->reductionSymbols = value1;
  return *this;
}
OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) {
genRegionEntryCB = value;
return *this;
@ -2323,6 +2342,11 @@ struct OpWithBodyGenInfo {
const Fortran::parser::OmpClauseList *clauses = nullptr;
/// [in] if provided, processes the construct's data-sharing attributes.
DataSharingProcessor *dsp = nullptr;
/// [in] if provided, list of reduction symbols
llvm::SmallVector<const Fortran::semantics::Symbol *> *reductionSymbols =
nullptr;
/// [in] if provided, list of reduction types
llvm::SmallVector<mlir::Type> *reductionTypes = nullptr;
/// [in] if provided, emits the op's region entry. Otherwise, an empty block
/// is created in the region.
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
@ -2567,6 +2591,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
ClauseProcessor cp(converter, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@ -2576,13 +2601,33 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
cp.processDefault();
cp.processAllocate(allocatorOperands, allocateOperands);
if (!outerCombined)
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
&reductionSymbols);
llvm::SmallVector<mlir::Type> reductionTypes;
reductionTypes.reserve(reductionVars.size());
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
[](mlir::Value v) { return v.getType(); });
auto reductionCallback = [&](mlir::Operation *op) {
llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
currentLocation);
auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
reductionTypes, locs);
for (auto [arg, prv] :
llvm::zip_equal(reductionSymbols, block->getArguments())) {
converter.bindSymbol(*arg, prv);
}
return reductionSymbols;
};
return genOpWithBody<mlir::omp::ParallelOp>(
OpWithBodyGenInfo(converter, currentLocation, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList),
.setClauses(&clauseList)
.setReductions(&reductionSymbols, &reductionTypes)
.setGenRegionEntryCb(reductionCallback),
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
numThreadsClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
@ -3634,10 +3679,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
}
if (singleDirective) {
genOpenMPReduction(converter, beginClauseList);
if (singleDirective)
return;
}
// Codegen for combined directives
bool combinedDirective = false;
@ -3673,7 +3716,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
")");
genNestedEvaluations(converter, eval);
genOpenMPReduction(converter, beginClauseList);
}
static void

View File

@ -27,9 +27,11 @@
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@ -48,9 +50,11 @@ end subroutine
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@ -72,11 +76,15 @@ end subroutine
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
!CHECK: fir.store %[[RES1]] to %[[PRV1]]
!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
!CHECK: fir.store %[[RES0]] to %[[PRV0]]
!CHECK: omp.terminator
!CHECK: }
!CHECK: return

View File

@ -28,9 +28,12 @@
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@ -50,9 +53,12 @@ end subroutine
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@ -76,11 +82,17 @@ end subroutine
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return

View File

@ -0,0 +1,38 @@
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
!CHECK: omp.reduction.declare @[[REDUCTION_DECLARE:[_a-z0-9]+]] : i32 init {
!CHECK: ^bb0(%{{.*}}: i32):
!CHECK: %[[I0:[_a-z0-9]+]] = arith.constant 0 : i32
!CHECK: omp.yield(%[[I0]] : i32)
!CHECK: } combiner {
!CHECK: ^bb0(%[[C0:[_a-z0-9]+]]: i32, %[[C1:[_a-z0-9]+]]: i32):
!CHECK: %[[CR:[_a-z0-9]+]] = arith.addi %[[C0]], %[[C1]] : i32
!CHECK: omp.yield(%[[CR]] : i32)
!CHECK: }
!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "mn"} {
!CHECK: %[[RED_ACCUM_REF:[_a-z0-9]+]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
!CHECK: %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
!CHECK: %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32
!CHECK: hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: %[[RED_ACCUM_VAL:[_a-z0-9]+]] = fir.load %[[RED_ACCUM_DECL]]#0 : !fir.ref<i32>
!CHECK: {{.*}} = fir.call @_FortranAioOutputInteger32(%{{.*}}, %[[RED_ACCUM_VAL]]) fastmath<contract> : (!fir.ref<i8>, i32) -> i1
!CHECK: return
!CHECK: }
program mn
integer :: i
i = 0
!$omp parallel reduction(+:i)
i = 1
!$omp end parallel
print *, i
end program

View File

@ -191,11 +191,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
unsigned getNumReductionVars() { return getReductionVars().size(); }
}];
let assemblyFormat = [{
oilist( `reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
) `)`
| `if` `(` $if_expr_var `:` type($if_expr_var) `)`
oilist(
`if` `(` $if_expr_var `:` type($if_expr_var) `)`
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
@ -203,7 +200,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
) $region attr-dict
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
}];
let hasVerifier = 1;
}

View File

@ -21,6 +21,7 @@
#include "mlir/Interfaces/FoldInterfaces.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
@ -34,6 +35,7 @@
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
#include "mlir/Support/LogicalResult.h"
using namespace mlir;
using namespace mlir::omp;
@ -427,6 +429,71 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//
/// Parses a `reduction` clause of the form
///
///   `reduction` `(` entry (`,` entry)* `)`
///   entry ::= symbol-ref operand `->` region-argument `:` type
///
/// For every entry the reduction symbol, the reduced operand, the private
/// region argument and the type are appended to the corresponding output
/// containers. Once the list has been parsed, each private argument is given
/// the type parsed for its entry, and `reductionSymbols` is set to an array
/// attribute holding the parsed symbols.
///
/// Returns failure both when the `reduction` keyword is absent and when the
/// clause following the keyword cannot be parsed. The `region` parameter is
/// not used by this function.
ParseResult
parseReductionClause(OpAsmParser &parser, Region &region,
                     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
                     SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
                     SmallVectorImpl<OpAsmParser::Argument> &privates) {
  // No keyword means there is no clause to parse.
  if (failed(parser.parseOptionalKeyword("reduction")))
    return failure();

  SmallVector<SymbolRefAttr> symbolRefs;

  // Parses a single `@sym %operand -> %private : type` entry. The pieces are
  // parsed strictly in this order, stopping at the first one that fails.
  auto parseEntry = [&]() {
    if (parser.parseAttribute(symbolRefs.emplace_back()))
      return failure();
    if (parser.parseOperand(operands.emplace_back()))
      return failure();
    if (parser.parseArrow())
      return failure();
    if (parser.parseArgument(privates.emplace_back()))
      return failure();
    if (parser.parseColonType(types.emplace_back()))
      return failure();
    return success();
  };

  if (failed(parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                            parseEntry)))
    return failure();

  // The private arguments are parsed without a type; attach the type that was
  // parsed at the end of the matching entry.
  for (auto [privateArg, entryType] : llvm::zip_equal(privates, types))
    privateArg.type = entryType;

  SmallVector<Attribute> symbolAttrs(symbolRefs.begin(), symbolRefs.end());
  reductionSymbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
  return success();
}
/// Prints a `reduction(...)` clause followed by a single space.
///
/// Each entry is printed as `symbol operand -> region-argument : type`, with
/// entries separated by commas. The region arguments are taken from the first
/// block of `region`. The symbol, operand, argument and type ranges are
/// zipped with `llvm::zip_equal`, so they are expected to have equal lengths.
/// The `op` parameter is not used by this function.
static void printReductionClause(OpAsmPrinter &p, Operation *op, Region &region,
                                 ValueRange operands, TypeRange types,
                                 ArrayAttr reductionSymbols) {
  // Prints one `symbol operand -> argument : type` entry.
  auto printEntry = [&p](auto entry) {
    auto [symbol, operand, privateArg, entryType] = entry;
    p << symbol << " " << operand << " -> " << privateArg << " : "
      << entryType;
  };

  auto entries = llvm::zip_equal(reductionSymbols, operands,
                                 region.front().getArguments(), types);

  p << "reduction(";
  llvm::interleaveComma(entries, p, printEntry);
  p << ") ";
}
/// Parses the trailing part of the custom `omp.parallel` syntax: an optional
/// `reduction(...)` clause followed by the operation's region.
///
/// When a reduction clause is present, the private arguments it introduces
/// are passed to the region parser as the region's arguments. When no clause
/// is present, the region is parsed without arguments.
///
/// A reduction clause that starts with the `reduction` keyword but is
/// malformed makes this function fail instead of falling through to parsing a
/// plain region.
static ParseResult
parseParallelRegion(OpAsmParser &parser, Region &region,
                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
                    SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
  llvm::SmallVector<OpAsmParser::Argument> privates;

  // parseReductionClause reports failure both for "keyword absent" and for
  // "keyword present but clause malformed". Remember where parsing started so
  // the two cases can be told apart afterwards.
  const auto clauseStart = parser.getCurrentLocation();

  if (succeeded(parseReductionClause(parser, region, operands, types,
                                     reductionSymbols, privates)))
    return parser.parseRegion(region, privates);

  // If the parser moved past the starting location, the `reduction` keyword
  // was consumed and the clause itself failed to parse. Propagate that
  // failure: continuing would either emit a misleading second error or accept
  // invalid input such as `reduction { ... }`.
  if (parser.getCurrentLocation() != clauseStart)
    return failure();

  // Nothing was consumed: there is no reduction clause at all.
  return parser.parseRegion(region);
}
/// Prints the trailing part of the custom `omp.parallel` syntax: the
/// `reduction(...)` clause when the reduction symbols attribute is set,
/// followed by the region. The region's entry block arguments are not printed
/// as part of the region.
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange operands, TypeRange types,
                                ArrayAttr reductionSymbols) {
  const bool hasReductionClause = static_cast<bool>(reductionSymbols);
  if (hasReductionClause)
    printReductionClause(p, op, region, operands, types, reductionSymbols);

  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
/// reduction-entry-list ::= reduction-entry
/// | reduction-entry-list `,` reduction-entry
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
@ -1114,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region &region,
loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
for (auto &iv : ivs)
iv.type = loopVarType;
return parser.parseRegion(region, ivs);
}

View File

@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// Allocate reduction vars
SmallVector<llvm::Value *> privateReductionVariables;
DenseMap<Value, llvm::Value *> reductionVariableMap;
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables,
reductionVariableMap);
{
llvm::IRBuilderBase::InsertPointGuard guard(builder);
builder.restoreIP(allocaIP);
auto args = opInst.getRegion().getArguments();
for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
llvm::Value *var = builder.CreateAlloca(
moduleTranslation.convertType(reductionDecls[i].getType()));
moduleTranslation.mapValue(args[i], var);
privateReductionVariables.push_back(var);
reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
}
}
// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating

View File

@ -640,11 +640,13 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
func.func @parallel_reduction() {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
// CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr)
omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
// CHECK: omp.parallel reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr)
omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
%1 = arith.constant 2.0 : f32
// CHECK: omp.reduction %{{.+}}, %{{.+}}
omp.reduction %1, %0 : f32, !llvm.ptr
%2 = llvm.load %prv : !llvm.ptr -> f32
// CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32
%3 = llvm.fadd %1, %2 : f32
llvm.store %3, %prv : f32, !llvm.ptr
omp.terminator
}
return
@ -654,13 +656,14 @@ func.func @parallel_reduction() {
func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
// CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
// CHECK: omp.parallel reduction(@add_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
// CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
%1 = arith.constant 2.0 : f32
// CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
omp.reduction %1, %0 : f32, !llvm.ptr
%2 = llvm.load %prv : !llvm.ptr -> f32
// CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
llvm.fadd %1, %2 : f32
// CHECK: omp.yield
omp.yield
}
@ -799,11 +802,14 @@ func.func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
// CHECK-LABEL: func @parallel_reduction2
func.func @parallel_reduction2() {
%0 = memref.alloca() : memref<1xf32>
// CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) {
// CHECK: omp.parallel reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>)
omp.parallel reduction(@add2_f32 %0 -> %prv : memref<1xf32>) {
%1 = arith.constant 2.0 : f32
// CHECK: omp.reduction
omp.reduction %1, %0 : f32, memref<1xf32>
%2 = arith.constant 0 : index
%3 = memref.load %prv[%2] : memref<1xf32>
// CHECK: llvm.fadd
%4 = llvm.fadd %1, %3 : f32
memref.store %4, %prv[%2] : memref<1xf32>
omp.terminator
}
return
@ -813,13 +819,14 @@ func.func @parallel_reduction2() {
func.func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
// CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr) {
omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr) {
// CHECK: omp.parallel reduction(@add2_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
omp.parallel reduction(@add2_f32 %0 -> %prv : !llvm.ptr) {
// CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
%1 = arith.constant 2.0 : f32
// CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
omp.reduction %1, %0 : f32, !llvm.ptr
%2 = llvm.load %prv : !llvm.ptr -> f32
// CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
%3 = llvm.fadd %1, %2 : f32
// CHECK: omp.yield
omp.yield
}

View File

@ -441,9 +441,11 @@ atomic {
llvm.func @simple_reduction_parallel() {
%c1 = llvm.mlir.constant(1 : i32) : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
%1 = llvm.mlir.constant(2.0 : f32) : f32
omp.reduction %1, %0 : f32, !llvm.ptr
%2 = llvm.load %prv : !llvm.ptr -> f32
%3 = llvm.fadd %2, %1 : f32
llvm.store %3, %prv : f32, !llvm.ptr
omp.terminator
}
llvm.return
@ -512,10 +514,12 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) {
%lb = llvm.mlir.constant(1 : i64) : i64
%step = llvm.mlir.constant(1 : i64) : i64
omp.parallel reduction(@add_i32 -> %0 : !llvm.ptr) {
omp.parallel reduction(@add_i32 %0 -> %prv : !llvm.ptr) {
omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
%ival = llvm.trunc %iv : i64 to i32
omp.reduction %ival, %0 : i32, !llvm.ptr
%lprv = llvm.load %prv : !llvm.ptr -> i32
%add = llvm.add %lprv, %ival : i32
llvm.store %add, %prv : i32, !llvm.ptr
omp.yield
}
omp.terminator