mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
The Triton autotuner ignores configs that use too much shmem
The autotuner runs a series of benchmarks to determine the best configuration for a Triton kernel. However, if it encounters a config that does not fit in shared memory, it throws an error and stops. In this eventuality it should just continue. PiperOrigin-RevId: 600730507
This commit is contained in:
parent
34d22fc498
commit
5761e393fa
@ -2,6 +2,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
@ -42,6 +43,29 @@ struct gpuModuleDeleter {
|
||||
using OwnedGPUmodule =
|
||||
std::unique_ptr<std::remove_pointer_t<gpuModule_t>, gpuModuleDeleter>;
|
||||
|
||||
// Returns the device on which `stream` was created.
//
// On HIP the stream's device id is available directly. On CUDA we must
// query the stream's context, make it current, and then ask for the
// current device.
absl::StatusOr<gpuDevice_t> GetStreamDevice(gpuStream_t stream) {
  gpuDevice_t device;
#ifdef JAX_GPU_HIP
  int device_id = gpuGetStreamDeviceId(stream);
  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
#else  // JAX_GPU_CUDA
  // Declared here (not above) so the variable is not unused under HIP.
  gpuContext_t context;
  GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
  GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
  // Pop the pushed context on every exit path, including error returns.
  absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
  GPU_RETURN_IF_ERROR(gpuCtxGetDevice(&device));
#endif
  return device;
}
|
||||
|
||||
// Queries `device` for the maximum opt-in shared memory available to a
// single block, in bytes.
absl::StatusOr<uint32_t> MaxSharedMemoryPerBlock(gpuDevice_t device) {
  int max_shared_bytes;
  GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
      &max_shared_bytes, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  return max_shared_bytes;
}
|
||||
|
||||
absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
|
||||
uint32_t shared_mem_bytes,
|
||||
std::string_view ptx,
|
||||
@ -248,19 +272,19 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
|
||||
compute_capability_));
|
||||
}
|
||||
|
||||
gpuContext_t context;
|
||||
#ifdef JAX_GPU_HIP
|
||||
int device_id = gpuGetStreamDeviceId(stream);
|
||||
gpuDevice_t device;
|
||||
GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
|
||||
GPU_RETURN_IF_ERROR(gpuDevicePrimaryCtxRetain(&context, device));
|
||||
JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
|
||||
module_image_->GetFunctionForContext(context));
|
||||
return JAX_AS_STATUS(gpuLaunchKernel(
|
||||
kernel, grid[0], grid[1], grid[2], block_dim_x_,
|
||||
/*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
|
||||
/*extra=*/nullptr));
|
||||
#else //JAX_GPU_CUDA
|
||||
gpuContext_t context;
|
||||
#ifdef JAX_GPU_HIP
|
||||
int device_id = gpuGetStreamDeviceId(stream);
|
||||
gpuDevice_t device;
|
||||
GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
|
||||
GPU_RETURN_IF_ERROR(gpuDevicePrimaryCtxRetain(&context, device));
|
||||
JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
|
||||
module_image_->GetFunctionForContext(context));
|
||||
return JAX_AS_STATUS(gpuLaunchKernel(
|
||||
kernel, grid[0], grid[1], grid[2], block_dim_x_,
|
||||
/*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
|
||||
/*extra=*/nullptr));
|
||||
#else // JAX_GPU_CUDA
|
||||
GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
|
||||
JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
|
||||
module_image_->GetFunctionForContext(context));
|
||||
@ -347,6 +371,10 @@ KernelCall::Parameter::FromProto(
|
||||
return param;
|
||||
}
|
||||
|
||||
// Returns true if the kernel's shared memory requirement fits within the
// per-block limit of `device`. If the limit cannot be queried, we fail
// closed: treat the limit as zero, so only kernels needing no shared
// memory are considered launchable.
bool Kernel::CanLaunchOnDevice(gpuDevice_t device) const {
  const uint32_t device_limit = MaxSharedMemoryPerBlock(device).value_or(0);
  return shared_mem_bytes_ <= device_limit;
}
|
||||
|
||||
jax_triton::TritonKernelCall_Parameter KernelCall::Parameter::ToProto() const {
|
||||
jax_triton::TritonKernelCall_Parameter proto;
|
||||
if (std::holds_alternative<Array>(value)) {
|
||||
@ -436,6 +464,10 @@ jax_triton::TritonKernelCall KernelCall::ToProto() const {
|
||||
return proto;
|
||||
}
|
||||
|
||||
// A kernel call is launchable on `device` iff its underlying kernel is.
bool KernelCall::CanLaunchOnDevice(gpuDevice_t device) const {
  const bool launchable = kernel_.CanLaunchOnDevice(device);
  return launchable;
}
|
||||
|
||||
AutotunedKernelCall::AutotunedKernelCall(
|
||||
std::string name, std::vector<Config> configs,
|
||||
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases)
|
||||
@ -527,7 +559,14 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
}
|
||||
|
||||
best = std::numeric_limits<float>::infinity();
|
||||
JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
|
||||
for (Config& config : kernel_call.configs_) {
|
||||
if (!config.kernel_call.CanLaunchOnDevice(device)) {
|
||||
LOG(WARNING) << "Unable to launch autotune config on device: "
|
||||
<< config.description;
|
||||
continue;
|
||||
}
|
||||
|
||||
JAX_ASSIGN_OR_RETURN(
|
||||
float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
|
||||
LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
|
||||
@ -539,6 +578,11 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
|
||||
std::swap(config, kernel_call.configs_[0]);
|
||||
}
|
||||
}
|
||||
if (std::isinf(best)) {
|
||||
LOG(WARNING) << "Finished autotuning function: " << kernel_call.name_
|
||||
<< " no valid configs found.";
|
||||
return absl::FailedPreconditionError("No launchable configs.");
|
||||
}
|
||||
|
||||
LOG(INFO) << "Finished autotuning function: " << kernel_call.name_
|
||||
<< " best config " << kernel_call.configs_[0].description;
|
||||
|
@ -10,6 +10,7 @@
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/triton.pb.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
@ -33,6 +34,9 @@ class Kernel {
|
||||
static Kernel FromProto(const jax_triton::TritonKernel& proto);
|
||||
jax_triton::TritonKernel ToProto() const;
|
||||
|
||||
// Returns true if we can launch the kernel without crashing.
|
||||
bool CanLaunchOnDevice(gpuDevice_t) const;
|
||||
|
||||
private:
|
||||
std::string kernel_name_;
|
||||
uint32_t block_dim_x_;
|
||||
@ -71,6 +75,9 @@ class KernelCall {
|
||||
const jax_triton::TritonKernelCall& proto);
|
||||
jax_triton::TritonKernelCall ToProto() const;
|
||||
|
||||
// Returns true if we can launch the kernel without crashing.
|
||||
bool CanLaunchOnDevice(gpuDevice_t) const;
|
||||
|
||||
private:
|
||||
Kernel kernel_;
|
||||
uint32_t grid_[3];
|
||||
@ -86,7 +93,8 @@ class AutotunedKernelCall {
|
||||
|
||||
AutotunedKernelCall(
|
||||
std::string name, std::vector<Config> configs,
|
||||
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases);
|
||||
std::vector<std::tuple<size_t,
|
||||
size_t, size_t>> input_output_aliases);
|
||||
|
||||
static absl::StatusOr<KernelCall> Autotune(AutotunedKernelCall kernel_call,
|
||||
gpuStream_t stream, void** buffers);
|
||||
|
Loading…
x
Reference in New Issue
Block a user