mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[jax-triton] Do not capture jax-triton calls that require autotuning
PiperOrigin-RevId: 611823473
This commit is contained in:
parent
8e2a8b7b95
commit
1ae2022918
@ -545,30 +545,15 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
// GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
|
||||
// absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
|
||||
|
||||
// If `stream` is in capture mode we can't run autotuning on it as we don't
|
||||
// want to capture it into a graph. We create a new stream to do autotuning
|
||||
// and destroy it when we are done.
|
||||
// Autotuning is not supported if the the stream is in graph capture mode.
|
||||
gpustreamCaptureStatus_t capture_status;
|
||||
GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
|
||||
bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE;
|
||||
|
||||
gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED;
|
||||
gpuStream_t autotune_stream = stream;
|
||||
|
||||
// An event that synchronizes autotuning stream with a main one.
|
||||
gpuEvent_t autotune_event = nullptr;
|
||||
|
||||
if (is_capturing) {
|
||||
GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
|
||||
|
||||
// Record event after completion of launched kernels on the main stream.
|
||||
GPU_RETURN_IF_ERROR(gpuEventCreate(&autotune_event, 0));
|
||||
GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, stream));
|
||||
|
||||
// Create a new stream to run autotuning and synchronize it with main sream.
|
||||
GPU_RETURN_IF_ERROR(
|
||||
gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));
|
||||
GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(autotune_stream, autotune_event));
|
||||
if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
|
||||
return absl::FailedPreconditionError(
|
||||
"Can't autotune Triton kernel when the stream is in graph capture "
|
||||
"mode. Autotuning can rely on real data present in input buffers to "
|
||||
"use them in address computation, but in graph capture mode buffers "
|
||||
"can have arbitrary data");
|
||||
}
|
||||
|
||||
// If an input aliases with an output, it will get overwritten during the
|
||||
@ -581,8 +566,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
std::vector<uint8_t> input_copy(size);
|
||||
GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
|
||||
input_copy.data(),
|
||||
reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size,
|
||||
autotune_stream));
|
||||
reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size, stream));
|
||||
input_copies[input_idx] = std::move(input_copy);
|
||||
}
|
||||
}
|
||||
@ -592,8 +576,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
// iterations to run for benchmarking.
|
||||
float best = std::numeric_limits<float>::infinity();
|
||||
for (Config& config : kernel_call.configs_) {
|
||||
JAX_ASSIGN_OR_RETURN(
|
||||
float t, Benchmark(autotune_stream, config.kernel_call, buffers, 1));
|
||||
JAX_ASSIGN_OR_RETURN(float t,
|
||||
Benchmark(stream, config.kernel_call, buffers, 1));
|
||||
LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
|
||||
best = std::min(best, t);
|
||||
}
|
||||
@ -609,7 +593,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
}
|
||||
|
||||
best = std::numeric_limits<float>::infinity();
|
||||
JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(autotune_stream));
|
||||
JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
|
||||
for (Config& config : kernel_call.configs_) {
|
||||
if (!config.kernel_call.CanLaunchOnDevice(device)) {
|
||||
LOG(WARNING) << "Unable to launch autotune config on device: "
|
||||
@ -617,8 +601,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
continue;
|
||||
}
|
||||
|
||||
JAX_ASSIGN_OR_RETURN(float t, Benchmark(autotune_stream, config.kernel_call,
|
||||
buffers, timed_iters));
|
||||
JAX_ASSIGN_OR_RETURN(
|
||||
float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
|
||||
LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
|
||||
<< t << " ms";
|
||||
|
||||
@ -639,25 +623,14 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
|
||||
// Restore aliased inputs to their original values.
|
||||
for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
|
||||
GPU_RETURN_IF_ERROR(gpuMemcpyHtoDAsync(
|
||||
reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
|
||||
input_copies[input_idx].data(), size, autotune_stream));
|
||||
GPU_RETURN_IF_ERROR(
|
||||
gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
|
||||
input_copies[input_idx].data(), size, stream));
|
||||
}
|
||||
|
||||
// Synchronize stream to ensure copies are complete before the host copy
|
||||
// is deleted.
|
||||
GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream));
|
||||
|
||||
if (is_capturing) {
|
||||
// Wait on a main stream for completion of autotuning.
|
||||
GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, autotune_stream));
|
||||
GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(stream, autotune_event));
|
||||
GPU_RETURN_IF_ERROR(gpuEventDestroy(autotune_event));
|
||||
|
||||
// Destroy autotuning stream and recover stream capturing mode.
|
||||
GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream));
|
||||
GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
|
||||
}
|
||||
GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));
|
||||
|
||||
return std::move(kernel_call.configs_[0].kernel_call);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user