diff --git a/flang-rt/lib/cuda/kernel.cpp b/flang-rt/lib/cuda/kernel.cpp index 6b60b72630a1..73b4e24bf701 100644 --- a/flang-rt/lib/cuda/kernel.cpp +++ b/flang-rt/lib/cuda/kernel.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "flang/Runtime/CUDA/kernel.h" +#include "flang-rt/runtime/descriptor.h" #include "flang-rt/runtime/terminator.h" #include "flang/Runtime/CUDA/common.h" @@ -74,9 +75,9 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY, Fortran::runtime::Terminator terminator{__FILE__, __LINE__}; terminator.Crash("Too many invalid grid dimensions"); } - cudaStream_t cuStream = 0; // TODO stream managment - CUDA_REPORT_IF_ERROR( - cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, cuStream)); + cudaStream_t defaultStream = 0; + CUDA_REPORT_IF_ERROR(cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, + stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream)); } void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX, @@ -140,7 +141,11 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX, terminator.Crash("Too many invalid grid dimensions"); } config.dynamicSmemBytes = smem; - config.stream = 0; // TODO stream managment + if (stream != kNoAsyncId) { + config.stream = (cudaStream_t)stream; + } else { + config.stream = 0; + } cudaLaunchAttribute launchAttr[1]; launchAttr[0].id = cudaLaunchAttributeClusterDimension; launchAttr[0].val.clusterDim.x = clusterX; @@ -212,9 +217,10 @@ void RTDEF(CUFLaunchCooperativeKernel)(const void *kernel, intptr_t gridX, Fortran::runtime::Terminator terminator{__FILE__, __LINE__}; terminator.Crash("Too many invalid grid dimensions"); } - cudaStream_t cuStream = 0; // TODO stream managment - CUDA_REPORT_IF_ERROR(cudaLaunchCooperativeKernel( - kernel, gridDim, blockDim, params, smem, cuStream)); + cudaStream_t defaultStream = 0; + CUDA_REPORT_IF_ERROR( + cudaLaunchCooperativeKernel(kernel, gridDim, blockDim, params, smem, + stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream)); } } // extern "C"