[jax_triton] Only use a side stream for autotuning when doing graph capture

When graph capture is not enabled, autotuning and the kernel launch should run on the same stream to avoid a race condition: benchmark launches issued on a side stream would be unordered with respect to other work on the launch stream even though both touch the same buffers.

PiperOrigin-RevId: 603728867
Author: Anlun Xu, 2024-02-02 10:47:40 -08:00 (committed by jax authors)
Parent: e1ea936fc1
Commit: 16636f9c97
3 changed files with 34 additions and 20 deletions
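
The patch boils down to: query whether the incoming launch stream is currently being captured into a GPU graph, and only in that case route the benchmark launches to a freshly created non-blocking side stream (tearing it down afterwards); otherwise autotune directly on the launch stream. A minimal standalone sketch of that control flow, written against the CUDA runtime API rather than jaxlib's gpu* driver-API wrappers (the function name is illustrative; error handling elided):

#include <cuda_runtime.h>

// Mirrors the control flow added to AutotunedKernelCall::Autotune below.
void AutotuneSketch(cudaStream_t stream) {
  // Is `stream` inside a cudaStreamBeginCapture()/cudaStreamEndCapture() pair?
  cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
  cudaStreamIsCapturing(stream, &capture_status);
  bool is_capturing = capture_status == cudaStreamCaptureStatusActive;

  cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeRelaxed;
  cudaStream_t autotune_stream = stream;
  if (is_capturing) {
    // Relaxed mode lets this thread issue work outside the capture
    // without invalidating the graph being recorded.
    cudaThreadExchangeStreamCaptureMode(&capture_mode);
    // Benchmark on a side stream so the timing launches are not captured.
    cudaStreamCreateWithFlags(&autotune_stream, cudaStreamNonBlocking);
  }

  // ... benchmark the candidate kernel configs on autotune_stream ...

  cudaStreamSynchronize(autotune_stream);
  if (is_capturing) {
    cudaStreamDestroy(autotune_stream);
    cudaThreadExchangeStreamCaptureMode(&capture_mode);  // restore capture mode
  }
}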


@@ -156,7 +156,7 @@ absl::StatusOr<KernelCall*> GetKernelCall(absl::string_view opaque,
     {
       JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_,
                            AutotunedKernelCall::Autotune(
-                               std::move(autotuned_call), buffers));
+                               std::move(autotuned_call), stream, buffers));
       kernel_call = std::make_unique<KernelCall>(std::move(kernel_call_));
     }
   } else {
@@ -515,19 +515,26 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
 }

 /*static*/ absl::StatusOr<KernelCall> AutotunedKernelCall::Autotune(
-    AutotunedKernelCall kernel_call, void** buffers) {
+    AutotunedKernelCall kernel_call, gpuStream_t stream, void** buffers) {
   // Ensure a valid context for driver calls that don't take the stream.
   //gpuContext_t context;
   //GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
   //GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
   //absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };

-  gpustreamCaptureMode_t capture_mode = CU_STREAM_CAPTURE_MODE_RELAXED;
-  GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
+  gpustreamCaptureStatus_t capture_status;
+  GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
+  bool is_capturing = capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE;

-  // Need a side stream so as not to interfere with graph capture.
-  gpuStream_t stream;
-  GPU_RETURN_IF_ERROR(gpuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
+  gpustreamCaptureMode_t capture_mode = CU_STREAM_CAPTURE_MODE_RELAXED;
+  gpuStream_t autotune_stream = stream;
+  if (is_capturing) {
+    GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
+    // Need a side stream so as not to interfere with graph capture.
+    GPU_RETURN_IF_ERROR(
+        gpuStreamCreate(&autotune_stream, CU_STREAM_NON_BLOCKING));
+  }

   // If an input aliases with an output, it will get overwritten during the
   // kernel execution. If the kernel is called repeatedly, as we do during
@@ -538,8 +545,9 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
     if (buffers[input_idx] == buffers[output_idx]) {
       std::vector<uint8_t> input_copy(size);
       GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
-          input_copy.data(), reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
-          size, stream));
+          input_copy.data(),
+          reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size,
+          autotune_stream));
       input_copies[input_idx] = std::move(input_copy);
     }
   }
@@ -549,8 +557,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   // iterations to run for benchmarking.
   float best = std::numeric_limits<float>::infinity();
   for (Config& config : kernel_call.configs_) {
-    JAX_ASSIGN_OR_RETURN(float t,
-                         Benchmark(stream, config.kernel_call, buffers, 1));
+    JAX_ASSIGN_OR_RETURN(
+        float t, Benchmark(autotune_stream, config.kernel_call, buffers, 1));
     LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
     best = std::min(best, t);
   }
@@ -566,7 +574,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   }

   best = std::numeric_limits<float>::infinity();
-  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
+  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(autotune_stream));
   for (Config& config : kernel_call.configs_) {
     if (!config.kernel_call.CanLaunchOnDevice(device)) {
       LOG(WARNING) << "Unable to launch autotune config on device: "
@@ -574,8 +582,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
       continue;
     }
-    JAX_ASSIGN_OR_RETURN(
-        float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
+    JAX_ASSIGN_OR_RETURN(float t, Benchmark(autotune_stream, config.kernel_call,
+                                            buffers, timed_iters));
     LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
               << t << " ms";
@@ -596,15 +604,18 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   // Restore aliased inputs to their original values.
   for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
-    GPU_RETURN_IF_ERROR(
-        gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
-                           input_copies[input_idx].data(), size, stream));
+    GPU_RETURN_IF_ERROR(gpuMemcpyHtoDAsync(
+        reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
+        input_copies[input_idx].data(), size, autotune_stream));
   }
   // Synchronize stream to ensure copies are complete before the host copy
   // is deleted.
-  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));
-  GPU_RETURN_IF_ERROR(gpuStreamDestroy(stream));
-  GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
+  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream));
+  if (is_capturing) {
+    GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream));
+    GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
+  }

   return std::move(kernel_call.configs_[0].kernel_call);
 }
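
The Benchmark() helper that the hunks above call is unchanged by this commit and not shown. As context, a stream-timed benchmark of that shape typically uses CUDA events, roughly as in this hypothetical sketch (the helper name, signature, and launch callback are assumptions; error handling elided):

#include <cuda_runtime.h>

// Times `iters` launches issued on `stream` and returns the elapsed
// milliseconds between the two recorded events.
template <typename LaunchFn>
float BenchmarkMs(cudaStream_t stream, LaunchFn launch, int iters) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, stream);   // marks the point before the launches
  for (int i = 0; i < iters; ++i) launch(stream);
  cudaEventRecord(stop, stream);    // marks the point after the launches
  cudaEventSynchronize(stop);       // wait until the timed work has finished
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}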


@@ -97,6 +97,7 @@ class AutotunedKernelCall {
                          size_t, size_t>> input_output_aliases);

   static absl::StatusOr<KernelCall> Autotune(AutotunedKernelCall kernel_call,
+                                             gpuStream_t stream,
                                              void** buffers);

   static absl::StatusOr<AutotunedKernelCall> FromProto(


@@ -66,6 +66,7 @@ typedef cublasStatus_t gpublasStatus_t;
 typedef cublasHandle_t gpublasHandle_t;
 typedef CUcontext gpuContext_t;
 typedef CUstreamCaptureMode gpustreamCaptureMode_t;
+typedef CUstreamCaptureStatus gpustreamCaptureStatus_t;
 typedef cudaDataType gpuDataType;
 typedef CUdevice gpuDevice_t;
 typedef CUdeviceptr gpuDevicePtr_t;
@@ -276,6 +277,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpuThreadExchangeStreamCaptureMode cuThreadExchangeStreamCaptureMode
 #define gpuStreamCreate cuStreamCreate
 #define gpuStreamDestroy cuStreamDestroy
+#define gpuStreamIsCapturing cuStreamIsCapturing
 #define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
   CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR
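
Only the CUDA branch of the vendor header appears in the hunks shown here. Assuming the header keeps its usual parallel ROCm branch, the corresponding HIP-side aliases would look roughly like this (not part of this diff; HIP exposes matching stream-capture queries):

// Hypothetical ROCm counterparts of the additions above.
typedef hipStreamCaptureStatus gpustreamCaptureStatus_t;
#define gpuStreamIsCapturing hipStreamIsCapturing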