[flang][cuda] Lower match_any_sync functions to nvvm intrinsics (#127942)

This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-02-20 13:45:15 -08:00 committed by GitHub
parent d1dde17ab0
commit 84c8848f81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 88 additions and 0 deletions

View File

@ -336,6 +336,7 @@ struct IntrinsicLibrary {
template <typename Shift>
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
/// Lower the CUDA Fortran match_any_sync(mask, value) device function to a
/// call to an NVVM match.any.sync function that returns an i32.
mlir::Value genMatchAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMatmulTranspose(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);

View File

@ -485,6 +485,22 @@ static constexpr IntrinsicHandler handlers[]{
&I::genMatchAllSync,
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
/*isElemental=*/false},
{"match_any_syncjd",
&I::genMatchAnySync,
{{{"mask", asValue}, {"value", asValue}}},
/*isElemental=*/false},
{"match_any_syncjf",
&I::genMatchAnySync,
{{{"mask", asValue}, {"value", asValue}}},
/*isElemental=*/false},
{"match_any_syncjj",
&I::genMatchAnySync,
{{{"mask", asValue}, {"value", asValue}}},
/*isElemental=*/false},
{"match_any_syncjx",
&I::genMatchAnySync,
{{{"mask", asValue}, {"value", asValue}}},
/*isElemental=*/false},
{"matmul",
&I::genMatmul,
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@ -6060,6 +6076,7 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
return result;
}
// MATCH_ALL_SYNC
mlir::Value
IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
@ -6096,6 +6113,32 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
return value;
}
// MATCH_ANY_SYNC
//
// Lowers match_any_sync(mask, value) to a fir.call of an external function
// named after the NVVM match.any.sync intrinsic.
//
// args[0] is the mask and is forwarded unchanged. args[1] is the value being
// matched: a 32-bit integer or f32 selects the 32-bit callee, anything else
// selects the 64-bit callee. Real values are first converted to the integer
// type of matching width with fir.convert. The call produces a single i32,
// which is returned; `resultType` is not consulted.
mlir::Value
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
                                  llvm::ArrayRef<mlir::Value> args) {
  assert(args.size() == 2);
  mlir::Value mask = args[0];
  mlir::Value value = args[1];
  mlir::Type valueTy = value.getType();
  const bool isReal = valueTy.isF32() || valueTy.isF64();
  const bool is32Bit = valueTy.isInteger(32) || valueTy.isF32();

  // Integer type the callee expects for its second operand.
  mlir::Type i32Ty = builder.getI32Type();
  mlir::Type operandTy = is32Bit ? i32Ty : builder.getI64Type();

  // NOTE(review): these names carry the same "p" suffix as the match.all
  // lowering, whose callee returns a (value, predicate) pair, whereas this
  // callee returns a plain i32. Confirm that the NVVM match.any intrinsics
  // are really spelled ...i32p/...i64p rather than ...i32/...i64.
  llvm::StringRef calleeName = "llvm.nvvm.match.any.sync.i64p";
  if (is32Bit)
    calleeName = "llvm.nvvm.match.any.sync.i32p";

  mlir::FunctionType calleeTy = mlir::FunctionType::get(
      builder.getContext(), {i32Ty, operandTy}, {i32Ty});
  auto callee = builder.createFunction(loc, calleeName, calleeTy);

  // NOTE(review): presumably the intent is to compare the raw bits of real
  // values; confirm that fir.convert real -> integer preserves the bit
  // pattern here rather than performing a numeric conversion.
  if (isReal)
    value = builder.create<fir::ConvertOp>(loc, operandTy, value);

  llvm::SmallVector<mlir::Value> callArgs{mask, value};
  return builder.create<fir::CallOp>(loc, callee, callArgs).getResult(0);
}
// MATMUL
fir::ExtendedValue
IntrinsicLibrary::genMatmul(mlir::Type resultType,

View File

@ -589,4 +589,27 @@ interface match_all_sync
end function
end interface
! Generic device interface match_any_sync(mask, val).
! One specific is declared per supported type of val; mask is always
! integer(4) and every specific returns a default integer. Both dummies are
! passed by value and carry ignore_tkr(d).
interface match_any_sync
  ! val is integer(4)
  attributes(device) integer function match_any_syncjj(mask, val)
    !dir$ ignore_tkr(d) mask, (d) val
    integer(4), value :: mask, val
  end function match_any_syncjj

  ! val is integer(8)
  attributes(device) integer function match_any_syncjx(mask, val)
    !dir$ ignore_tkr(d) mask, (d) val
    integer(4), value :: mask
    integer(8), value :: val
  end function match_any_syncjx

  ! val is real(4)
  attributes(device) integer function match_any_syncjf(mask, val)
    !dir$ ignore_tkr(d) mask, (d) val
    integer(4), value :: mask
    real(4), value :: val
  end function match_any_syncjf

  ! val is real(8)
  attributes(device) integer function match_any_syncjd(mask, val)
    !dir$ ignore_tkr(d) mask, (d) val
    integer(4), value :: mask
    real(8), value :: val
  end function match_any_syncjd
end interface match_any_sync
end module

View File

@ -131,6 +131,25 @@ end subroutine
! CHECK: fir.convert %{{.*}} : (f64) -> i64
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
! Lowering test input: invokes the match_any_sync generic once for each
! value type, in the order default integer, integer(8), real(4), real(8).
! The expectations that follow this subroutine rely on that call order.
attributes(device) subroutine testMatchAny()
integer :: a, mask, v32
integer(8) :: v64
real(4) :: r4
real(8) :: r8
a = match_any_sync(mask, v32)
a = match_any_sync(mask, v64)
a = match_any_sync(mask, r4)
a = match_any_sync(mask, r8)
end subroutine
! CHECK-LABEL: func.func @_QPtestmatchany()
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
! CHECK: fir.convert %{{.*}} : (f32) -> i32
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
! CHECK: fir.convert %{{.*}} : (f64) -> i64
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
! CHECK: func.func private @llvm.nvvm.barrier0()
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
! CHECK: func.func private @llvm.nvvm.membar.gl()
@ -141,3 +160,5 @@ end subroutine
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32