[flang][cuda] Use the provided stream in kernel launch (#135267)

This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-04-10 17:15:23 -07:00 committed by GitHub
parent 1cd59264aa
commit 1d8966e246
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Runtime/CUDA/kernel.h"
#include "flang-rt/runtime/descriptor.h"
#include "flang-rt/runtime/terminator.h"
#include "flang/Runtime/CUDA/common.h"
@@ -74,9 +75,9 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
terminator.Crash("Too many invalid grid dimensions");
}
cudaStream_t cuStream = 0; // TODO stream management
CUDA_REPORT_IF_ERROR(
cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, cuStream));
cudaStream_t defaultStream = 0;
CUDA_REPORT_IF_ERROR(cudaLaunchKernel(kernel, gridDim, blockDim, params, smem,
stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream));
}
void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
@@ -140,7 +141,11 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
terminator.Crash("Too many invalid grid dimensions");
}
config.dynamicSmemBytes = smem;
config.stream = 0; // TODO stream management
if (stream != kNoAsyncId) {
config.stream = (cudaStream_t)stream;
} else {
config.stream = 0;
}
cudaLaunchAttribute launchAttr[1];
launchAttr[0].id = cudaLaunchAttributeClusterDimension;
launchAttr[0].val.clusterDim.x = clusterX;
@@ -212,9 +217,10 @@ void RTDEF(CUFLaunchCooperativeKernel)(const void *kernel, intptr_t gridX,
Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
terminator.Crash("Too many invalid grid dimensions");
}
cudaStream_t cuStream = 0; // TODO stream management
CUDA_REPORT_IF_ERROR(cudaLaunchCooperativeKernel(
kernel, gridDim, blockDim, params, smem, cuStream));
cudaStream_t defaultStream = 0;
CUDA_REPORT_IF_ERROR(
cudaLaunchCooperativeKernel(kernel, gridDim, blockDim, params, smem,
stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream));
}
} // extern "C"