The Triton autotuner ignores configs that use too much shmem

The autotuner runs a series of benchmarks to determine the best configuration
for a Triton kernel. However, if it encounters a config that does not fit in
shared memory, it throws an error and stops. In this eventuality it should
instead skip that config and continue with the remaining ones.

PiperOrigin-RevId: 600730507
jax authors 2024-01-23 03:08:25 -08:00
parent 34d22fc498
commit 5761e393fa
2 changed files with 66 additions and 14 deletions
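
In effect, the patch turns "config does not fit in shared memory" from a fatal error into a skipped benchmark, and autotuning only fails if every config is unlaunchable. Below is a self-contained sketch of that control flow; Config, Benchmark, and PickBestConfig are hypothetical stand-ins for the real types in the diff that follows, not the actual jaxlib API.

#include <cstdio>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for an autotune config; the real KernelCall tracks
// its shared memory requirement similarly.
struct Config {
  const char* description;
  int shared_mem_bytes;
};

// Pretend benchmark returning a runtime in ms (assumption for illustration).
float Benchmark(const Config& config) { return config.shared_mem_bytes * 1e-6f; }

// The fixed behavior: skip configs that exceed the device limit instead of
// erroring out; report failure only when nothing was launchable.
std::optional<Config> PickBestConfig(const std::vector<Config>& configs,
                                     int max_shared_mem_bytes) {
  float best = std::numeric_limits<float>::infinity();
  std::optional<Config> best_config;
  for (const Config& config : configs) {
    if (config.shared_mem_bytes > max_shared_mem_bytes) continue;  // skip, don't abort
    float t = Benchmark(config);
    if (t < best) {
      best = t;
      best_config = config;
    }
  }
  return best_config;  // nullopt maps to "No launchable configs."
}

int main() {
  std::vector<Config> configs = {{"big", 256 * 1024}, {"small", 32 * 1024}};
  auto best = PickBestConfig(configs, /*max_shared_mem_bytes=*/164 * 1024);
  std::printf("best: %s\n", best ? best->description : "none");
  return 0;
}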

jaxlib/gpu/triton_kernels.cc

@@ -2,6 +2,7 @@
 #include <algorithm>
 #include <cstdint>
+#include <cmath>
 #include <memory>
 #include <string>
 #include <string_view>
@@ -42,6 +43,29 @@ struct gpuModuleDeleter {
 using OwnedGPUmodule =
     std::unique_ptr<std::remove_pointer_t<gpuModule_t>, gpuModuleDeleter>;
 
+absl::StatusOr<gpuDevice_t> GetStreamDevice(gpuStream_t stream) {
+  gpuDevice_t device;
+  gpuContext_t context;
+#ifdef JAX_GPU_HIP
+  int device_id = gpuGetStreamDeviceId(stream);
+  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
+#else  // JAX_GPU_CUDA
+  GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
+  GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
+  absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
+  GPU_RETURN_IF_ERROR(gpuCtxGetDevice(&device));
+#endif
+  return device;
+}
+
+absl::StatusOr<uint32_t> MaxSharedMemoryPerBlock(gpuDevice_t device) {
+  int shared_optin;
+  GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
+      &shared_optin, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+      device));
+  return shared_optin;
+}
+
 absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
                                             uint32_t shared_mem_bytes,
                                             std::string_view ptx,
@@ -248,19 +272,19 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
                                    compute_capability_));
   }
-  gpuContext_t context;
-#ifdef JAX_GPU_HIP
-  int device_id = gpuGetStreamDeviceId(stream);
-  gpuDevice_t device;
-  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
-  GPU_RETURN_IF_ERROR(gpuDevicePrimaryCtxRetain(&context, device));
-  JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
-                       module_image_->GetFunctionForContext(context));
-  return JAX_AS_STATUS(gpuLaunchKernel(
-      kernel, grid[0], grid[1], grid[2], block_dim_x_,
-      /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
-      /*extra=*/nullptr));
-#else //JAX_GPU_CUDA
+  gpuContext_t context;
+#ifdef JAX_GPU_HIP
+  int device_id = gpuGetStreamDeviceId(stream);
+  gpuDevice_t device;
+  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
+  GPU_RETURN_IF_ERROR(gpuDevicePrimaryCtxRetain(&context, device));
+  JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
+                       module_image_->GetFunctionForContext(context));
+  return JAX_AS_STATUS(gpuLaunchKernel(
+      kernel, grid[0], grid[1], grid[2], block_dim_x_,
+      /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
+      /*extra=*/nullptr));
+#else  // JAX_GPU_CUDA
   GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
   JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
                        module_image_->GetFunctionForContext(context));
@@ -347,6 +371,10 @@ KernelCall::Parameter::FromProto(
   return param;
 }
 
+bool Kernel::CanLaunchOnDevice(gpuDevice_t device) const {
+  return shared_mem_bytes_ <= MaxSharedMemoryPerBlock(device).value_or(0);
+}
+
 jax_triton::TritonKernelCall_Parameter KernelCall::Parameter::ToProto() const {
   jax_triton::TritonKernelCall_Parameter proto;
   if (std::holds_alternative<Array>(value)) {
@@ -436,6 +464,10 @@ jax_triton::TritonKernelCall KernelCall::ToProto() const {
   return proto;
 }
 
+bool KernelCall::CanLaunchOnDevice(gpuDevice_t device) const {
+  return kernel_.CanLaunchOnDevice(device);
+}
+
 AutotunedKernelCall::AutotunedKernelCall(
     std::string name, std::vector<Config> configs,
     std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases)
@@ -527,7 +559,14 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   }
 
   best = std::numeric_limits<float>::infinity();
+  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
   for (Config& config : kernel_call.configs_) {
+    if (!config.kernel_call.CanLaunchOnDevice(device)) {
+      LOG(WARNING) << "Unable to launch autotune config on device: "
+                   << config.description;
+      continue;
+    }
+
     JAX_ASSIGN_OR_RETURN(
         float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
     LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
@@ -539,6 +578,11 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
       std::swap(config, kernel_call.configs_[0]);
     }
   }
+  if (std::isinf(best)) {
+    LOG(WARNING) << "Finished autotuning function: " << kernel_call.name_
+                 << " no valid configs found.";
+    return absl::FailedPreconditionError("No launchable configs.");
+  }
   LOG(INFO) << "Finished autotuning function: " << kernel_call.name_
             << " best config " << kernel_call.configs_[0].description;

jaxlib/gpu/triton_kernels.h

@@ -10,6 +10,7 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/triton.pb.h"
 #include "jaxlib/gpu/vendor.h"
 #include "xla/service/custom_call_status.h"
@@ -33,6 +34,9 @@ class Kernel {
   static Kernel FromProto(const jax_triton::TritonKernel& proto);
   jax_triton::TritonKernel ToProto() const;
 
+  // Returns true if we can launch the kernel without crashing.
+  bool CanLaunchOnDevice(gpuDevice_t) const;
+
  private:
   std::string kernel_name_;
   uint32_t block_dim_x_;
@@ -71,6 +75,9 @@ class KernelCall {
       const jax_triton::TritonKernelCall& proto);
   jax_triton::TritonKernelCall ToProto() const;
 
+  // Returns true if we can launch the kernel without crashing.
+  bool CanLaunchOnDevice(gpuDevice_t) const;
+
  private:
   Kernel kernel_;
   uint32_t grid_[3];
@@ -86,7 +93,8 @@ class AutotunedKernelCall {
   AutotunedKernelCall(
       std::string name, std::vector<Config> configs,
-      std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases);
+      std::vector<std::tuple<size_t,
+                             size_t, size_t>> input_output_aliases);
 
   static absl::StatusOr<KernelCall> Autotune(AutotunedKernelCall kernel_call,
                                              gpuStream_t stream, void** buffers);
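
Callers see the guard only indirectly: with this change, an autotune pass in which every config oversubscribes shared memory resolves to a FailedPrecondition status rather than a crash. A hypothetical caller-side sketch against the declarations above, where autotuned_call, stream, and buffers are assumed to already exist:

// Autotune takes the AutotunedKernelCall by value, per the declaration above.
absl::StatusOr<KernelCall> best =
    AutotunedKernelCall::Autotune(std::move(autotuned_call), stream, buffers);
if (!best.ok()) {
  // e.g. "FAILED_PRECONDITION: No launchable configs." when nothing fits.
  LOG(ERROR) << "Autotuning failed: " << best.status();
}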