From 726c4b9f77862d83b6e5e16c8d5a2fc4fb1589a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Thu, 20 Feb 2025 09:10:25 -0800 Subject: [PATCH] [flang][cuda] Lower match_all_sync functions to nvvm intrinsics (#127940) --- .../flang/Optimizer/Builder/IntrinsicCall.h | 1 + flang/include/flang/Semantics/tools.h | 1 + flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 52 +++++++++++++++++++ flang/lib/Optimizer/CodeGen/CodeGen.cpp | 7 +++ flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp | 3 +- flang/module/cudadevice.f90 | 27 ++++++++++ flang/test/Lower/CUDA/cuda-device-proc.cuf | 21 ++++++++ 7 files changed, 111 insertions(+), 1 deletion(-) diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h index 65732ce7f322..caec6a913293 100644 --- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h +++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h @@ -335,6 +335,7 @@ struct IntrinsicLibrary { mlir::Value genMalloc(mlir::Type, llvm::ArrayRef); template mlir::Value genMask(mlir::Type, llvm::ArrayRef); + mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef); fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef); fir::ExtendedValue genMatmulTranspose(mlir::Type, llvm::ArrayRef); diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h index e82446a2ba88..56dcfa88ad92 100644 --- a/flang/include/flang/Semantics/tools.h +++ b/flang/include/flang/Semantics/tools.h @@ -231,6 +231,7 @@ inline bool NeedCUDAAlloc(const Symbol &sym) { (*details->cudaDataAttr() == common::CUDADataAttr::Device || *details->cudaDataAttr() == common::CUDADataAttr::Managed || *details->cudaDataAttr() == common::CUDADataAttr::Unified || + *details->cudaDataAttr() == common::CUDADataAttr::Shared || *details->cudaDataAttr() == common::CUDADataAttr::Pinned)) { return true; } diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 93744fa58ebc..754496921ca3 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -469,6 +469,22 @@ static constexpr IntrinsicHandler handlers[]{ {"malloc", &I::genMalloc}, {"maskl", &I::genMask}, {"maskr", &I::genMask}, + {"match_all_syncjd", + &I::genMatchAllSync, + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, + {"match_all_syncjf", + &I::genMatchAllSync, + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, + {"match_all_syncjj", + &I::genMatchAllSync, + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, + {"match_all_syncjx", + &I::genMatchAllSync, + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, {"matmul", &I::genMatmul, {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}}, @@ -6044,6 +6060,42 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType, return result; } +mlir::Value +IntrinsicLibrary::genMatchAllSync(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 3); + bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32(); + + llvm::StringRef funcName = + is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p"; + mlir::MLIRContext *context = builder.getContext(); + mlir::Type i32Ty = builder.getI32Type(); + mlir::Type i64Ty = builder.getI64Type(); + mlir::Type i1Ty = builder.getI1Type(); + mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty}); + mlir::Type valTy = is32 ? i32Ty : i64Ty; + + mlir::FunctionType ftype = + mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy}); + auto funcOp = builder.createFunction(loc, funcName, ftype); + llvm::SmallVector filteredArgs; + filteredArgs.push_back(args[0]); + if (args[1].getType().isF32() || args[1].getType().isF64()) + filteredArgs.push_back(builder.create(loc, valTy, args[1])); + else + filteredArgs.push_back(args[1]); + auto call = builder.create(loc, funcOp, filteredArgs); + auto zero = builder.getIntegerAttr(builder.getIndexType(), 0); + auto value = builder.create( + loc, resultType, call.getResult(0), builder.getArrayAttr(zero)); + auto one = builder.getIntegerAttr(builder.getIndexType(), 1); + auto pred = builder.create(loc, i1Ty, call.getResult(0), + builder.getArrayAttr(one)); + auto conv = builder.create(loc, resultType, pred); + builder.create(loc, conv, args[2]); + return value; +} + // MATMUL fir::ExtendedValue IntrinsicLibrary::genMatmul(mlir::Type resultType, diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index c76b7cde55bd..439cc7a85623 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -292,6 +292,12 @@ struct AllocaOpConversion : public fir::FIROpConversion { rewriter.setInsertionPointAfter(size.getDefiningOp()); } + if (auto dataAttr = alloc->getAttrOfType( + cuf::getDataAttrName())) { + if (dataAttr.getValue() == cuf::DataAttribute::Shared) + allocaAs = 3; + } + // NOTE: we used to pass alloc->getAttrs() in the builder for non opaque // pointers! Only propagate pinned and bindc_name to help debugging, but // this should have no functional purpose (and passing the operand segment @@ -316,6 +322,7 @@ struct AllocaOpConversion : public fir::FIROpConversion { rewriter.replaceOpWithNewOp( alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc); } + return mlir::success(); } }; diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index b05991a29a32..fa82f3916a57 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -57,7 +57,8 @@ static llvm::LogicalResult checkCudaAttr(Op op) { if (op.getDataAttr() == cuf::DataAttribute::Device || op.getDataAttr() == cuf::DataAttribute::Managed || op.getDataAttr() == cuf::DataAttribute::Unified || - op.getDataAttr() == cuf::DataAttribute::Pinned) + op.getDataAttr() == cuf::DataAttribute::Pinned || + op.getDataAttr() == cuf::DataAttribute::Shared) return mlir::success(); return op.emitOpError() << "expect device, managed, pinned or unified cuda attribute"; diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90 index e473590a7d78..c75c5c191ab5 100644 --- a/flang/module/cudadevice.f90 +++ b/flang/module/cudadevice.f90 @@ -562,4 +562,31 @@ implicit none end function end interface +interface match_all_sync + attributes(device) integer function match_all_syncjj(mask, val, pred) +!dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + integer(4), value :: val + integer(4) :: pred + end function + attributes(device) integer function match_all_syncjx(mask, val, pred) +!dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + integer(8), value :: val + integer(4) :: pred + end function + attributes(device) integer function match_all_syncjf(mask, val, pred) +!dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + real(4), value :: val + integer(4) :: pred + end function + attributes(device) integer function match_all_syncjd(mask, val, pred) +!dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + real(8), value :: val + integer(4) :: pred + end function +end interface + end module diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index 6a5524102c0e..1210dae8608c 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -112,6 +112,25 @@ end ! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath : (i32) -> i32 ! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath : (i32) -> i32 +attributes(device) subroutine testMatch() + integer :: a, ipred, mask, v32 + integer(8) :: v64 + real(4) :: r4 + real(8) :: r8 + a = match_all_sync(mask, v32, ipred) + a = match_all_sync(mask, v64, ipred) + a = match_all_sync(mask, r4, ipred) + a = match_all_sync(mask, r8, ipred) +end subroutine + +! CHECK-LABEL: func.func @_QPtestmatch() +! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p +! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p +! CHECK: fir.convert %{{.*}} : (f32) -> i32 +! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p +! CHECK: fir.convert %{{.*}} : (f64) -> i64 +! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p + ! CHECK: func.func private @llvm.nvvm.barrier0() ! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32) ! CHECK: func.func private @llvm.nvvm.membar.gl() @@ -120,3 +139,5 @@ end ! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32 ! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32 ! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32 +! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple +! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple