From 1d8966e2465364dca8aea71d9331d46a0e908265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Thu, 10 Apr 2025 17:15:23 -0700 Subject: [PATCH] [flang][cuda] Use the provided stream in kernel launch (#135267) --- flang-rt/lib/cuda/kernel.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/flang-rt/lib/cuda/kernel.cpp b/flang-rt/lib/cuda/kernel.cpp index 6b60b72630a1..73b4e24bf701 100644 --- a/flang-rt/lib/cuda/kernel.cpp +++ b/flang-rt/lib/cuda/kernel.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "flang/Runtime/CUDA/kernel.h" +#include "flang-rt/runtime/descriptor.h" #include "flang-rt/runtime/terminator.h" #include "flang/Runtime/CUDA/common.h" @@ -74,9 +75,9 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY, Fortran::runtime::Terminator terminator{__FILE__, __LINE__}; terminator.Crash("Too many invalid grid dimensions"); } - cudaStream_t cuStream = 0; // TODO stream managment - CUDA_REPORT_IF_ERROR( - cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, cuStream)); + cudaStream_t defaultStream = 0; + CUDA_REPORT_IF_ERROR(cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, + stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream)); } void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX, @@ -140,7 +141,11 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX, terminator.Crash("Too many invalid grid dimensions"); } config.dynamicSmemBytes = smem; - config.stream = 0; // TODO stream managment + if (stream != kNoAsyncId) { + config.stream = (cudaStream_t)stream; + } else { + config.stream = 0; + } cudaLaunchAttribute launchAttr[1]; launchAttr[0].id = cudaLaunchAttributeClusterDimension; launchAttr[0].val.clusterDim.x = clusterX; @@ -212,9 +217,10 @@ void RTDEF(CUFLaunchCooperativeKernel)(const void *kernel, intptr_t gridX, Fortran::runtime::Terminator terminator{__FILE__, __LINE__}; terminator.Crash("Too many invalid grid dimensions"); } - cudaStream_t cuStream = 0; // TODO stream managment - CUDA_REPORT_IF_ERROR(cudaLaunchCooperativeKernel( - kernel, gridDim, blockDim, params, smem, cuStream)); + cudaStream_t defaultStream = 0; + CUDA_REPORT_IF_ERROR( + cudaLaunchCooperativeKernel(kernel, gridDim, blockDim, params, smem, + stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream)); } } // extern "C"