From 16636f9c97414d0c5195c6fd47227756d4754095 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Fri, 2 Feb 2024 10:47:40 -0800 Subject: [PATCH] [jax_triton] Only use side stream to do autotuning when doing graph capture When graph capture is not enabled, autotuning and kernel launch should be on the same stream to avoid race condition. PiperOrigin-RevId: 603728867 --- jaxlib/gpu/triton_kernels.cc | 51 ++++++++++++++++++++++-------------- jaxlib/gpu/triton_kernels.h | 1 + jaxlib/gpu/vendor.h | 2 ++ 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 3cd2dbcac..88d451a81 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -156,7 +156,7 @@ absl::StatusOr GetKernelCall(absl::string_view opaque, { JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_, AutotunedKernelCall::Autotune( - std::move(autotuned_call), buffers)); + std::move(autotuned_call), stream, buffers)); kernel_call = std::make_unique(std::move(kernel_call_)); } } else { @@ -515,19 +515,26 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { } /*static*/ absl::StatusOr AutotunedKernelCall::Autotune( - AutotunedKernelCall kernel_call, void** buffers) { + AutotunedKernelCall kernel_call, gpuStream_t stream, void** buffers) { // Ensure a valid context for driver calls that don't take the stream. //gpuContext_t context; //GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context)); //GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); //absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; - gpustreamCaptureMode_t capture_mode = CU_STREAM_CAPTURE_MODE_RELAXED; - GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode)); + gpustreamCaptureStatus_t capture_status; + GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status)); + bool is_capturing = capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE; - // Need a side stream so as not to interfere with graph capture. - gpuStream_t stream; - GPU_RETURN_IF_ERROR(gpuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); + gpustreamCaptureMode_t capture_mode = CU_STREAM_CAPTURE_MODE_RELAXED; + gpuStream_t autotune_stream = stream; + + if (is_capturing) { + GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode)); + // Need a side stream so as not to interfere with graph capture. + GPU_RETURN_IF_ERROR( + gpuStreamCreate(&autotune_stream, CU_STREAM_NON_BLOCKING)); + } // If an input aliases with an output, it will get overwritten during the // kernel execution. If the kernel is called repeatedly, as we do during @@ -538,8 +545,9 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { if (buffers[input_idx] == buffers[output_idx]) { std::vector input_copy(size); GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync( - input_copy.data(), reinterpret_cast(buffers[input_idx]), - size, stream)); + input_copy.data(), + reinterpret_cast(buffers[input_idx]), size, + autotune_stream)); input_copies[input_idx] = std::move(input_copy); } } @@ -549,8 +557,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { // iterations to run for benchmarking. float best = std::numeric_limits::infinity(); for (Config& config : kernel_call.configs_) { - JAX_ASSIGN_OR_RETURN(float t, - Benchmark(stream, config.kernel_call, buffers, 1)); + JAX_ASSIGN_OR_RETURN( + float t, Benchmark(autotune_stream, config.kernel_call, buffers, 1)); LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms"; best = std::min(best, t); } @@ -566,7 +574,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { } best = std::numeric_limits::infinity(); - JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream)); + JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(autotune_stream)); for (Config& config : kernel_call.configs_) { if (!config.kernel_call.CanLaunchOnDevice(device)) { LOG(WARNING) << "Unable to launch autotune config on device: " @@ -574,8 +582,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { continue; } - JAX_ASSIGN_OR_RETURN( - float t, Benchmark(stream, config.kernel_call, buffers, timed_iters)); + JAX_ASSIGN_OR_RETURN(float t, Benchmark(autotune_stream, config.kernel_call, + buffers, timed_iters)); LOG(INFO) << config.description << ", ran " << timed_iters << " iters in " << t << " ms"; @@ -596,15 +604,18 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { // Restore aliased inputs to their original values. for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) { - GPU_RETURN_IF_ERROR( - gpuMemcpyHtoDAsync(reinterpret_cast(buffers[input_idx]), - input_copies[input_idx].data(), size, stream)); + GPU_RETURN_IF_ERROR(gpuMemcpyHtoDAsync( + reinterpret_cast(buffers[input_idx]), + input_copies[input_idx].data(), size, autotune_stream)); } // Synchronize stream to ensure copies are complete before the host copy // is deleted. - GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream)); - GPU_RETURN_IF_ERROR(gpuStreamDestroy(stream)); - GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode)); + GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream)); + + if (is_capturing) { + GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream)); + GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode)); + } return std::move(kernel_call.configs_[0].kernel_call); } diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index ac6f36791..c3457093c 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -97,6 +97,7 @@ class AutotunedKernelCall { size_t, size_t>> input_output_aliases); static absl::StatusOr Autotune(AutotunedKernelCall kernel_call, + gpuStream_t stream, void** buffers); static absl::StatusOr FromProto( diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 1b3023000..62ef6c4cd 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -66,6 +66,7 @@ typedef cublasStatus_t gpublasStatus_t; typedef cublasHandle_t gpublasHandle_t; typedef CUcontext gpuContext_t; typedef CUstreamCaptureMode gpustreamCaptureMode_t; +typedef CUstreamCaptureStatus gpustreamCaptureStatus_t; typedef cudaDataType gpuDataType; typedef CUdevice gpuDevice_t; typedef CUdeviceptr gpuDevicePtr_t; @@ -276,6 +277,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuThreadExchangeStreamCaptureMode cuThreadExchangeStreamCaptureMode #define gpuStreamCreate cuStreamCreate #define gpuStreamDestroy cuStreamDestroy +#define gpuStreamIsCapturing cuStreamIsCapturing #define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \ CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR