mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-17 08:06:40 +00:00
[flang][cuda] Lower match_any_sync functions to nvvm intrinsics (#127942)
This commit is contained in:
parent
d1dde17ab0
commit
84c8848f81
@ -336,6 +336,7 @@ struct IntrinsicLibrary {
|
||||
template <typename Shift>
|
||||
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
mlir::Value genMatchAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
|
||||
fir::ExtendedValue genMatmulTranspose(mlir::Type,
|
||||
llvm::ArrayRef<fir::ExtendedValue>);
|
||||
|
@ -485,6 +485,22 @@ static constexpr IntrinsicHandler handlers[]{
|
||||
&I::genMatchAllSync,
|
||||
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
|
||||
/*isElemental=*/false},
|
||||
{"match_any_syncjd",
|
||||
&I::genMatchAnySync,
|
||||
{{{"mask", asValue}, {"value", asValue}}},
|
||||
/*isElemental=*/false},
|
||||
{"match_any_syncjf",
|
||||
&I::genMatchAnySync,
|
||||
{{{"mask", asValue}, {"value", asValue}}},
|
||||
/*isElemental=*/false},
|
||||
{"match_any_syncjj",
|
||||
&I::genMatchAnySync,
|
||||
{{{"mask", asValue}, {"value", asValue}}},
|
||||
/*isElemental=*/false},
|
||||
{"match_any_syncjx",
|
||||
&I::genMatchAnySync,
|
||||
{{{"mask", asValue}, {"value", asValue}}},
|
||||
/*isElemental=*/false},
|
||||
{"matmul",
|
||||
&I::genMatmul,
|
||||
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
|
||||
@ -6060,6 +6076,7 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
|
||||
return result;
|
||||
}
|
||||
|
||||
// MATCH_ALL_SYNC
|
||||
mlir::Value
|
||||
IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
@ -6096,6 +6113,32 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
|
||||
return value;
|
||||
}
|
||||
|
||||
// MATCH_ANY_SYNC
|
||||
mlir::Value
|
||||
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 2);
|
||||
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
|
||||
|
||||
llvm::StringRef funcName =
|
||||
is32 ? "llvm.nvvm.match.any.sync.i32p" : "llvm.nvvm.match.any.sync.i64p";
|
||||
mlir::MLIRContext *context = builder.getContext();
|
||||
mlir::Type i32Ty = builder.getI32Type();
|
||||
mlir::Type i64Ty = builder.getI64Type();
|
||||
mlir::Type valTy = is32 ? i32Ty : i64Ty;
|
||||
|
||||
mlir::FunctionType ftype =
|
||||
mlir::FunctionType::get(context, {i32Ty, valTy}, {i32Ty});
|
||||
auto funcOp = builder.createFunction(loc, funcName, ftype);
|
||||
llvm::SmallVector<mlir::Value> filteredArgs;
|
||||
filteredArgs.push_back(args[0]);
|
||||
if (args[1].getType().isF32() || args[1].getType().isF64())
|
||||
filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
|
||||
else
|
||||
filteredArgs.push_back(args[1]);
|
||||
return builder.create<fir::CallOp>(loc, funcOp, filteredArgs).getResult(0);
|
||||
}
|
||||
|
||||
// MATMUL
|
||||
fir::ExtendedValue
|
||||
IntrinsicLibrary::genMatmul(mlir::Type resultType,
|
||||
|
@ -589,4 +589,27 @@ interface match_all_sync
|
||||
end function
|
||||
end interface
|
||||
|
||||
! Generic device interface for match_any_sync(mask, val).
! There is one specific procedure per supported type of val; the name suffix
! identifies it: jj = integer(4), jx = integer(8), jf = real(4), jd = real(8).
! Every specific takes an integer(4) mask by value and returns a default
! integer.
interface match_any_sync
  ! val is a 32-bit integer.
  attributes(device) integer function match_any_syncjj(mask, val)
  !dir$ ignore_tkr(d) mask, (d) val
    integer(kind=4), value :: mask
    integer(kind=4), value :: val
  end function match_any_syncjj

  ! val is a 64-bit integer.
  attributes(device) integer function match_any_syncjx(mask, val)
  !dir$ ignore_tkr(d) mask, (d) val
    integer(kind=4), value :: mask
    integer(kind=8), value :: val
  end function match_any_syncjx

  ! val is a 32-bit real.
  attributes(device) integer function match_any_syncjf(mask, val)
  !dir$ ignore_tkr(d) mask, (d) val
    integer(kind=4), value :: mask
    real(kind=4), value :: val
  end function match_any_syncjf

  ! val is a 64-bit real.
  attributes(device) integer function match_any_syncjd(mask, val)
  !dir$ ignore_tkr(d) mask, (d) val
    integer(kind=4), value :: mask
    real(kind=8), value :: val
  end function match_any_syncjd
end interface match_any_sync
|
||||
|
||||
end module
|
||||
|
@ -131,6 +131,25 @@ end subroutine
|
||||
! CHECK: fir.convert %{{.*}} : (f64) -> i64
|
||||
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
|
||||
|
||||
! Calls match_any_sync once per supported value type, in this order:
! default integer, integer(8), real(4), real(8). The order of the calls is
! significant because the FileCheck expectations that follow this subroutine
! are matched sequentially.
attributes(device) subroutine testMatchAny()
  integer :: a
  integer :: mask
  integer :: v32
  integer(8) :: v64
  real(4) :: r4
  real(8) :: r8

  a = match_any_sync(mask, v32)
  a = match_any_sync(mask, v64)
  a = match_any_sync(mask, r4)
  a = match_any_sync(mask, r8)
end subroutine testMatchAny
|
||||
|
||||
! CHECK-LABEL: func.func @_QPtestmatchany()
|
||||
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
|
||||
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
|
||||
! CHECK: fir.convert %{{.*}} : (f32) -> i32
|
||||
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
|
||||
! CHECK: fir.convert %{{.*}} : (f64) -> i64
|
||||
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
|
||||
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0()
|
||||
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
|
||||
! CHECK: func.func private @llvm.nvvm.membar.gl()
|
||||
@ -141,3 +160,5 @@ end subroutine
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
|
||||
! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
|
||||
! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32
|
||||
|
Loading…
x
Reference in New Issue
Block a user