[jax-triton] Do not capture jax-triton calls that require autotuning

PiperOrigin-RevId: 611823473
This commit is contained in:
Eugene Zhulenev 2024-03-01 10:28:11 -08:00 committed by jax authors
parent 8e2a8b7b95
commit 1ae2022918

View File

@ -545,30 +545,15 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
// GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
// absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
// If `stream` is in capture mode we can't run autotuning on it as we don't
// want to capture it into a graph. We create a new stream to do autotuning
// and destroy it when we are done.
// Autotuning is not supported if the stream is in graph capture mode.
gpustreamCaptureStatus_t capture_status;
GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE;
gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED;
gpuStream_t autotune_stream = stream;
// An event that synchronizes autotuning stream with a main one.
gpuEvent_t autotune_event = nullptr;
if (is_capturing) {
GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
// Record event after completion of launched kernels on the main stream.
GPU_RETURN_IF_ERROR(gpuEventCreate(&autotune_event, 0));
GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, stream));
// Create a new stream to run autotuning and synchronize it with main stream.
GPU_RETURN_IF_ERROR(
gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));
GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(autotune_stream, autotune_event));
if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
return absl::FailedPreconditionError(
"Can't autotune Triton kernel when the stream is in graph capture "
"mode. Autotuning can rely on real data present in input buffers to "
"use them in address computation, but in graph capture mode buffers "
"can have arbitrary data");
}
// If an input aliases with an output, it will get overwritten during the
@ -581,8 +566,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
std::vector<uint8_t> input_copy(size);
GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
input_copy.data(),
reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size,
autotune_stream));
reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size, stream));
input_copies[input_idx] = std::move(input_copy);
}
}
@ -592,8 +576,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
// iterations to run for benchmarking.
float best = std::numeric_limits<float>::infinity();
for (Config& config : kernel_call.configs_) {
JAX_ASSIGN_OR_RETURN(
float t, Benchmark(autotune_stream, config.kernel_call, buffers, 1));
JAX_ASSIGN_OR_RETURN(float t,
Benchmark(stream, config.kernel_call, buffers, 1));
LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
best = std::min(best, t);
}
@ -609,7 +593,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
}
best = std::numeric_limits<float>::infinity();
JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(autotune_stream));
JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
for (Config& config : kernel_call.configs_) {
if (!config.kernel_call.CanLaunchOnDevice(device)) {
LOG(WARNING) << "Unable to launch autotune config on device: "
@ -617,8 +601,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
continue;
}
JAX_ASSIGN_OR_RETURN(float t, Benchmark(autotune_stream, config.kernel_call,
buffers, timed_iters));
JAX_ASSIGN_OR_RETURN(
float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
<< t << " ms";
@ -639,25 +623,14 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
// Restore aliased inputs to their original values.
for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
GPU_RETURN_IF_ERROR(gpuMemcpyHtoDAsync(
reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
input_copies[input_idx].data(), size, autotune_stream));
GPU_RETURN_IF_ERROR(
gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
input_copies[input_idx].data(), size, stream));
}
// Synchronize stream to ensure copies are complete before the host copy
// is deleted.
GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream));
if (is_capturing) {
// Wait on a main stream for completion of autotuning.
GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, autotune_stream));
GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(stream, autotune_event));
GPU_RETURN_IF_ERROR(gpuEventDestroy(autotune_event));
// Destroy autotuning stream and recover stream capturing mode.
GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream));
GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
}
GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));
return std::move(kernel_call.configs_[0].kernel_call);
}