mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-18 19:16:43 +00:00
[flang][cuda] Lower match_all_sync functions to nvvm intrinsics (#127940)
This commit is contained in:
parent
02e8fd7a30
commit
726c4b9f77
@ -335,6 +335,7 @@ struct IntrinsicLibrary {
|
||||
mlir::Value genMalloc(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
template <typename Shift>
|
||||
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
|
||||
fir::ExtendedValue genMatmulTranspose(mlir::Type,
|
||||
llvm::ArrayRef<fir::ExtendedValue>);
|
||||
|
@ -231,6 +231,7 @@ inline bool NeedCUDAAlloc(const Symbol &sym) {
|
||||
(*details->cudaDataAttr() == common::CUDADataAttr::Device ||
|
||||
*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
|
||||
*details->cudaDataAttr() == common::CUDADataAttr::Unified ||
|
||||
*details->cudaDataAttr() == common::CUDADataAttr::Shared ||
|
||||
*details->cudaDataAttr() == common::CUDADataAttr::Pinned)) {
|
||||
return true;
|
||||
}
|
||||
|
@ -469,6 +469,22 @@ static constexpr IntrinsicHandler handlers[]{
|
||||
{"malloc", &I::genMalloc},
|
||||
{"maskl", &I::genMask<mlir::arith::ShLIOp>},
|
||||
{"maskr", &I::genMask<mlir::arith::ShRUIOp>},
|
||||
{"match_all_syncjd",
|
||||
&I::genMatchAllSync,
|
||||
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
|
||||
/*isElemental=*/false},
|
||||
{"match_all_syncjf",
|
||||
&I::genMatchAllSync,
|
||||
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
|
||||
/*isElemental=*/false},
|
||||
{"match_all_syncjj",
|
||||
&I::genMatchAllSync,
|
||||
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
|
||||
/*isElemental=*/false},
|
||||
{"match_all_syncjx",
|
||||
&I::genMatchAllSync,
|
||||
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
|
||||
/*isElemental=*/false},
|
||||
{"matmul",
|
||||
&I::genMatmul,
|
||||
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
|
||||
@ -6044,6 +6060,42 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
|
||||
return result;
|
||||
}
|
||||
|
||||
mlir::Value
|
||||
IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 3);
|
||||
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
|
||||
|
||||
llvm::StringRef funcName =
|
||||
is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p";
|
||||
mlir::MLIRContext *context = builder.getContext();
|
||||
mlir::Type i32Ty = builder.getI32Type();
|
||||
mlir::Type i64Ty = builder.getI64Type();
|
||||
mlir::Type i1Ty = builder.getI1Type();
|
||||
mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty});
|
||||
mlir::Type valTy = is32 ? i32Ty : i64Ty;
|
||||
|
||||
mlir::FunctionType ftype =
|
||||
mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy});
|
||||
auto funcOp = builder.createFunction(loc, funcName, ftype);
|
||||
llvm::SmallVector<mlir::Value> filteredArgs;
|
||||
filteredArgs.push_back(args[0]);
|
||||
if (args[1].getType().isF32() || args[1].getType().isF64())
|
||||
filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
|
||||
else
|
||||
filteredArgs.push_back(args[1]);
|
||||
auto call = builder.create<fir::CallOp>(loc, funcOp, filteredArgs);
|
||||
auto zero = builder.getIntegerAttr(builder.getIndexType(), 0);
|
||||
auto value = builder.create<fir::ExtractValueOp>(
|
||||
loc, resultType, call.getResult(0), builder.getArrayAttr(zero));
|
||||
auto one = builder.getIntegerAttr(builder.getIndexType(), 1);
|
||||
auto pred = builder.create<fir::ExtractValueOp>(loc, i1Ty, call.getResult(0),
|
||||
builder.getArrayAttr(one));
|
||||
auto conv = builder.create<mlir::LLVM::ZExtOp>(loc, resultType, pred);
|
||||
builder.create<fir::StoreOp>(loc, conv, args[2]);
|
||||
return value;
|
||||
}
|
||||
|
||||
// MATMUL
|
||||
fir::ExtendedValue
|
||||
IntrinsicLibrary::genMatmul(mlir::Type resultType,
|
||||
|
@ -292,6 +292,12 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
|
||||
rewriter.setInsertionPointAfter(size.getDefiningOp());
|
||||
}
|
||||
|
||||
if (auto dataAttr = alloc->getAttrOfType<cuf::DataAttributeAttr>(
|
||||
cuf::getDataAttrName())) {
|
||||
if (dataAttr.getValue() == cuf::DataAttribute::Shared)
|
||||
allocaAs = 3;
|
||||
}
|
||||
|
||||
// NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
|
||||
// pointers! Only propagate pinned and bindc_name to help debugging, but
|
||||
// this should have no functional purpose (and passing the operand segment
|
||||
@ -316,6 +322,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
|
||||
rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
|
||||
alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc);
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
@ -57,7 +57,8 @@ static llvm::LogicalResult checkCudaAttr(Op op) {
|
||||
if (op.getDataAttr() == cuf::DataAttribute::Device ||
|
||||
op.getDataAttr() == cuf::DataAttribute::Managed ||
|
||||
op.getDataAttr() == cuf::DataAttribute::Unified ||
|
||||
op.getDataAttr() == cuf::DataAttribute::Pinned)
|
||||
op.getDataAttr() == cuf::DataAttribute::Pinned ||
|
||||
op.getDataAttr() == cuf::DataAttribute::Shared)
|
||||
return mlir::success();
|
||||
return op.emitOpError()
|
||||
<< "expect device, managed, pinned or unified cuda attribute";
|
||||
|
@ -562,4 +562,31 @@ implicit none
|
||||
end function
|
||||
end interface
|
||||
|
||||
interface match_all_sync
|
||||
attributes(device) integer function match_all_syncjj(mask, val, pred)
|
||||
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
|
||||
integer(4), value :: mask
|
||||
integer(4), value :: val
|
||||
integer(4) :: pred
|
||||
end function
|
||||
attributes(device) integer function match_all_syncjx(mask, val, pred)
|
||||
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
|
||||
integer(4), value :: mask
|
||||
integer(8), value :: val
|
||||
integer(4) :: pred
|
||||
end function
|
||||
attributes(device) integer function match_all_syncjf(mask, val, pred)
|
||||
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
|
||||
integer(4), value :: mask
|
||||
real(4), value :: val
|
||||
integer(4) :: pred
|
||||
end function
|
||||
attributes(device) integer function match_all_syncjd(mask, val, pred)
|
||||
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
|
||||
integer(4), value :: mask
|
||||
real(8), value :: val
|
||||
integer(4) :: pred
|
||||
end function
|
||||
end interface
|
||||
|
||||
end module
|
||||
|
@ -112,6 +112,25 @@ end
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
|
||||
attributes(device) subroutine testMatch()
|
||||
integer :: a, ipred, mask, v32
|
||||
integer(8) :: v64
|
||||
real(4) :: r4
|
||||
real(8) :: r8
|
||||
a = match_all_sync(mask, v32, ipred)
|
||||
a = match_all_sync(mask, v64, ipred)
|
||||
a = match_all_sync(mask, r4, ipred)
|
||||
a = match_all_sync(mask, r8, ipred)
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPtestmatch()
|
||||
! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
|
||||
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
|
||||
! CHECK: fir.convert %{{.*}} : (f32) -> i32
|
||||
! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
|
||||
! CHECK: fir.convert %{{.*}} : (f64) -> i64
|
||||
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
|
||||
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0()
|
||||
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
|
||||
! CHECK: func.func private @llvm.nvvm.membar.gl()
|
||||
@ -120,3 +139,5 @@ end
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
|
||||
! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
|
||||
|
Loading…
x
Reference in New Issue
Block a user