mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

This allows users to register the callback from C++ when not using the default call target name. PiperOrigin-RevId: 544029098
707 lines
26 KiB
C++
707 lines
26 KiB
C++
#include "jaxlib/gpu/triton.h"

#include <zlib.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "pybind11/buffer_info.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "absl/base/call_once.h"
#include "absl/base/optimization.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "pybind11_abseil/status_casters.h"  // IWYU pragma: keep
#include "xla/service/custom_call_status.h"
#include "xla/stream_executor/gpu/asm_compiler.h"
|
|
|
|
// Converts the result of a CUDA driver call to a status via `JAX_AS_STATUS`
// and hands it to `JAX_RETURN_IF_ERROR` (both macros are defined outside this
// file).
#define CUDA_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))

// Token-pasting helpers; the extra level of indirection forces `__LINE__` to
// be expanded before it is concatenated.
#define JAX_TRITON_CONCAT_IMPL(a, b) a##b
#define JAX_TRITON_CONCAT(a, b) JAX_TRITON_CONCAT_IMPL(a, b)

// Implementation detail of `JAX_ASSIGN_OR_RETURN`; `statusor` is the name of
// the temporary that holds the evaluated `rexpr`.
#define JAX_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
  auto statusor = (rexpr);                              \
  if (ABSL_PREDICT_FALSE(!statusor.ok())) {             \
    return statusor.status();                           \
  }                                                     \
  lhs = (*std::move(statusor))

// Evaluates `rexpr` (a status-or-value expression). On failure, returns its
// status from the enclosing function; on success, move-assigns the value to
// `lhs`, which may be a declaration such as `int x`.
//
// The temporary is named after the current line so the macro can be used more
// than once in the same scope. It still expands to several statements, so it
// must not be the sole body of an unbraced `if`/`for`/`while`, and two uses
// must not share one source line.
#define JAX_ASSIGN_OR_RETURN(lhs, rexpr)                                      \
  JAX_ASSIGN_OR_RETURN_IMPL(JAX_TRITON_CONCAT(jax_statusor_, __LINE__), lhs, \
                            rexpr)
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace jax::JAX_GPU_NAMESPACE {
|
|
|
|
// TODO(cjfj): Move this to `gpu_kernel_helpers`?
|
|
// Used via JAX_AS_STATUS(expr) macro.
|
|
absl::Status AsStatus(CUresult error, const char* file, std::int64_t line,
|
|
const char* expr) {
|
|
if (ABSL_PREDICT_TRUE(error == CUDA_SUCCESS)) {
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
const char* str;
|
|
CHECK_EQ(cuGetErrorName(error, &str), CUDA_SUCCESS);
|
|
return absl::InternalError(
|
|
absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr, str));
|
|
}
|
|
|
|
} // namespace jax::JAX_GPU_NAMESPACE
|
|
|
|
namespace jax_triton {
|
|
namespace {
|
|
|
|
// Threads per warp; used to convert between a kernel's warp count and its x
// block dimension.
constexpr std::uint32_t kNumThreadsPerWarp{32};
|
|
|
|
struct CuModuleDeleter {
|
|
void operator()(CUmodule module) { cuModuleUnload(module); }
|
|
};
|
|
|
|
using OwnedCUmodule =
|
|
std::unique_ptr<std::remove_pointer_t<CUmodule>, CuModuleDeleter>;
|
|
|
|
class ModuleImage {
|
|
public:
|
|
ModuleImage(std::string_view kernel_name, std::vector<uint8_t> module_image,
|
|
uint32_t shared_mem_bytes)
|
|
: kernel_name_(kernel_name),
|
|
module_image_(std::move(module_image)),
|
|
shared_mem_bytes_(shared_mem_bytes) {}
|
|
|
|
absl::StatusOr<CUfunction> GetFunctionForContext(CUcontext context) {
|
|
absl::MutexLock lock(&mutex_);
|
|
auto it = functions_.find(context);
|
|
if (ABSL_PREDICT_TRUE(it != functions_.end())) {
|
|
return it->second;
|
|
}
|
|
|
|
CUDA_RETURN_IF_ERROR(cuCtxPushCurrent(context));
|
|
absl::Cleanup ctx_restorer = [] { cuCtxPopCurrent(nullptr); };
|
|
|
|
CUmodule module;
|
|
CUDA_RETURN_IF_ERROR(cuModuleLoadData(&module, module_image_.data()));
|
|
modules_.push_back(OwnedCUmodule(module, CuModuleDeleter()));
|
|
|
|
CUfunction function;
|
|
CUDA_RETURN_IF_ERROR(
|
|
cuModuleGetFunction(&function, module, kernel_name_.c_str()));
|
|
auto [_, success] = functions_.insert({context, function});
|
|
CHECK(success);
|
|
|
|
// The maximum permitted static shared memory allocation in CUDA is 48kB,
|
|
// but we can expose more to the kernel using dynamic shared memory.
|
|
constexpr int kMaxStaticSharedMemBytes = 49152;
|
|
if (shared_mem_bytes_ <= kMaxStaticSharedMemBytes) {
|
|
return function;
|
|
}
|
|
|
|
// Set up dynamic shared memory.
|
|
CUdevice device;
|
|
CUDA_RETURN_IF_ERROR(cuCtxGetDevice(&device));
|
|
|
|
int shared_optin;
|
|
CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
|
|
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
|
device));
|
|
|
|
if (shared_mem_bytes_ > shared_optin) {
|
|
return absl::InvalidArgumentError(
|
|
"Shared memory requested exceeds device resources.");
|
|
}
|
|
|
|
if (shared_optin > kMaxStaticSharedMemBytes) {
|
|
CUDA_RETURN_IF_ERROR(
|
|
cuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED));
|
|
int shared_total;
|
|
CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
|
|
&shared_total,
|
|
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
|
|
int shared_static;
|
|
CUDA_RETURN_IF_ERROR(cuFuncGetAttribute(
|
|
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
|
|
CUDA_RETURN_IF_ERROR(cuFuncSetAttribute(
|
|
function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
|
shared_optin - shared_static));
|
|
}
|
|
return function;
|
|
}
|
|
|
|
private:
|
|
std::string kernel_name_;
|
|
std::vector<uint8_t> module_image_;
|
|
uint32_t shared_mem_bytes_;
|
|
|
|
absl::Mutex mutex_;
|
|
std::vector<OwnedCUmodule> modules_ ABSL_GUARDED_BY(mutex_);
|
|
absl::flat_hash_map<CUcontext, CUfunction> functions_ ABSL_GUARDED_BY(mutex_);
|
|
};
|
|
|
|
// Returns the process-wide `ModuleImage` for the given kernel, compiling `ptx`
// on first request.
//
// Images are cached in a map that is never destroyed, keyed by
// (kernel name, shared memory bytes, PTX text, compute capability), and the
// returned pointer stays owned by that map. `compute_capability` is encoded as
// `major * 10 + minor`. Calls are serialised by an internal mutex, which is
// held for the duration of the compilation.
absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
                                            uint32_t shared_mem_bytes,
                                            std::string_view ptx,
                                            int compute_capability) {
  // The key owns a copy of the PTX. The cache outlives every caller, so a
  // `string_view` key would dangle once the caller's string is destroyed (for
  // example when autotuning discards the losing configs).
  auto key = std::make_tuple(kernel_name, shared_mem_bytes, std::string(ptx),
                             compute_capability);

  static absl::Mutex mutex;
  static auto& module_images =
      *new absl::flat_hash_map<decltype(key), std::unique_ptr<ModuleImage>>
          ABSL_GUARDED_BY(mutex);

  absl::MutexLock lock(&mutex);
  auto it = module_images.find(key);
  if (it != module_images.end()) return it->second.get();

  // TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable?
  int cc_major = compute_capability / 10;
  int cc_minor = compute_capability % 10;
  // Pass the owned copy's `c_str()`: it is guaranteed NUL-terminated, which a
  // `string_view` is not.
  JAX_ASSIGN_OR_RETURN(
      std::vector<uint8_t> module_image,
      stream_executor::CompileGpuAsm(cc_major, cc_minor,
                                     std::get<2>(key).c_str(),
                                     stream_executor::GpuAsmOpts{}));

  auto [it2, success] = module_images.insert(
      {std::move(key),
       std::make_unique<ModuleImage>(
           std::move(kernel_name), std::move(module_image), shared_mem_bytes)});
  CHECK(success);
  return it2->second.get();
}
|
|
|
|
class Kernel {
|
|
public:
|
|
Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes,
|
|
std::string ptx, std::string ttir, int compute_capability)
|
|
: kernel_name_(std::move(kernel_name)),
|
|
block_dim_x_(num_warps * kNumThreadsPerWarp),
|
|
shared_mem_bytes_(shared_mem_bytes),
|
|
ptx_(std::move(ptx)),
|
|
ttir_(std::move(ttir)),
|
|
compute_capability_(compute_capability) {}
|
|
|
|
absl::Status Launch(CUstream stream, uint32_t grid[3], void** params) {
|
|
if (ABSL_PREDICT_FALSE(module_image_ == nullptr)) {
|
|
JAX_ASSIGN_OR_RETURN(module_image_,
|
|
GetModuleImage(kernel_name_, shared_mem_bytes_, ptx_,
|
|
compute_capability_));
|
|
}
|
|
|
|
CUcontext context;
|
|
CUDA_RETURN_IF_ERROR(cuStreamGetCtx(stream, &context));
|
|
JAX_ASSIGN_OR_RETURN(CUfunction kernel,
|
|
module_image_->GetFunctionForContext(context));
|
|
return JAX_AS_STATUS(cuLaunchKernel(
|
|
kernel, grid[0], grid[1], grid[2], block_dim_x_,
|
|
/*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
|
|
/*extra=*/nullptr));
|
|
}
|
|
|
|
static Kernel FromProto(const TritonKernel& proto) {
|
|
return Kernel(proto.kernel_name(), proto.num_warps(),
|
|
proto.shared_mem_bytes(), proto.ptx(), proto.ttir(),
|
|
proto.compute_capability());
|
|
}
|
|
|
|
TritonKernel ToProto() const {
|
|
TritonKernel proto;
|
|
proto.set_kernel_name(kernel_name_);
|
|
proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp);
|
|
proto.set_shared_mem_bytes(shared_mem_bytes_);
|
|
proto.set_ptx(ptx_);
|
|
proto.set_ttir(ttir_);
|
|
proto.set_compute_capability(compute_capability_);
|
|
return proto;
|
|
}
|
|
|
|
private:
|
|
std::string kernel_name_;
|
|
uint32_t block_dim_x_;
|
|
uint32_t shared_mem_bytes_;
|
|
std::string ptx_;
|
|
std::string ttir_;
|
|
int compute_capability_;
|
|
|
|
ModuleImage* module_image_ = nullptr;
|
|
};
|
|
|
|
struct KernelCallBase {
|
|
virtual ~KernelCallBase() = default;
|
|
virtual absl::Status Launch(CUstream stream, void** buffers) = 0;
|
|
};
|
|
|
|
class KernelCall : public KernelCallBase {
|
|
public:
|
|
struct Parameter {
|
|
struct Array {
|
|
size_t bytes_to_zero;
|
|
size_t ptr_divisibility;
|
|
};
|
|
|
|
static absl::StatusOr<Parameter> FromProto(
|
|
const TritonKernelCall_Parameter& proto) {
|
|
Parameter param;
|
|
switch (proto.value_case()) {
|
|
case TritonKernelCall_Parameter::kArray:
|
|
param.value = Array{proto.array().bytes_to_zero(),
|
|
proto.array().ptr_divisibility()};
|
|
break;
|
|
case TritonKernelCall_Parameter::kBool:
|
|
param.value = proto.bool_();
|
|
break;
|
|
case TritonKernelCall_Parameter::kI32:
|
|
param.value = proto.i32();
|
|
break;
|
|
case TritonKernelCall_Parameter::kU32:
|
|
param.value = proto.u32();
|
|
break;
|
|
case TritonKernelCall_Parameter::kI64:
|
|
param.value = proto.i64();
|
|
break;
|
|
case TritonKernelCall_Parameter::kU64:
|
|
param.value = proto.u64();
|
|
break;
|
|
default:
|
|
return absl::InvalidArgumentError("Unknown scalar parameter type.");
|
|
}
|
|
return param;
|
|
}
|
|
|
|
TritonKernelCall_Parameter ToProto() const {
|
|
TritonKernelCall_Parameter proto;
|
|
if (std::holds_alternative<Array>(value)) {
|
|
proto.mutable_array()->set_bytes_to_zero(
|
|
std::get<Array>(value).bytes_to_zero);
|
|
proto.mutable_array()->set_ptr_divisibility(
|
|
std::get<Array>(value).ptr_divisibility);
|
|
} else if (std::holds_alternative<bool>(value)) {
|
|
proto.set_bool_(std::get<bool>(value));
|
|
} else if (std::holds_alternative<int32_t>(value)) {
|
|
proto.set_i32(std::get<int32_t>(value));
|
|
} else if (std::holds_alternative<uint32_t>(value)) {
|
|
proto.set_u32(std::get<uint32_t>(value));
|
|
} else if (std::holds_alternative<int64_t>(value)) {
|
|
proto.set_i64(std::get<int64_t>(value));
|
|
} else {
|
|
CHECK(std::holds_alternative<uint64_t>(value));
|
|
proto.set_u64(std::get<uint64_t>(value));
|
|
}
|
|
return proto;
|
|
}
|
|
|
|
std::variant<Array, bool, int32_t, uint32_t, int64_t, uint64_t> value;
|
|
};
|
|
|
|
KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2,
|
|
std::vector<Parameter> parameters)
|
|
: kernel_(std::move(kernel)),
|
|
grid_{grid_0, grid_1, grid_2},
|
|
parameters_(std::move(parameters)) {}
|
|
|
|
absl::Status Launch(CUstream stream, void** buffers) override final {
|
|
std::vector<void*> params;
|
|
params.reserve(parameters_.size());
|
|
for (size_t i = 0; i < parameters_.size(); ++i) {
|
|
const Parameter& param = parameters_[i];
|
|
if (std::holds_alternative<Parameter::Array>(param.value)) {
|
|
const auto& array = std::get<Parameter::Array>(param.value);
|
|
void*& ptr = *(buffers++);
|
|
auto cu_ptr = reinterpret_cast<CUdeviceptr>(ptr);
|
|
|
|
if (ABSL_PREDICT_FALSE((array.ptr_divisibility != 0) &&
|
|
(cu_ptr % array.ptr_divisibility != 0))) {
|
|
return absl::InvalidArgumentError(
|
|
absl::StrFormat("Parameter %zu (%p) is not divisible by %d.", i,
|
|
ptr, array.ptr_divisibility));
|
|
}
|
|
|
|
if (array.bytes_to_zero > 0) {
|
|
CUDA_RETURN_IF_ERROR(
|
|
cuMemsetD8Async(cu_ptr, 0, array.bytes_to_zero, stream));
|
|
}
|
|
params.push_back(&ptr);
|
|
} else {
|
|
params.push_back(const_cast<void*>(std::visit(
|
|
[](auto&& arg) { return reinterpret_cast<const void*>(&arg); },
|
|
param.value)));
|
|
}
|
|
}
|
|
|
|
return kernel_.Launch(stream, grid_, params.data());
|
|
}
|
|
|
|
static absl::StatusOr<KernelCall> FromProto(const TritonKernelCall& proto) {
|
|
std::vector<KernelCall::Parameter> parameters;
|
|
for (const TritonKernelCall_Parameter& parameter : proto.parameters()) {
|
|
JAX_ASSIGN_OR_RETURN(Parameter p, Parameter::FromProto(parameter));
|
|
parameters.push_back(p);
|
|
}
|
|
|
|
return KernelCall(Kernel::FromProto(proto.kernel()), proto.grid_0(),
|
|
proto.grid_1(), proto.grid_2(), std::move(parameters));
|
|
}
|
|
|
|
TritonKernelCall ToProto() const {
|
|
TritonKernelCall proto;
|
|
*proto.mutable_kernel() = kernel_.ToProto();
|
|
proto.set_grid_0(grid_[0]);
|
|
proto.set_grid_1(grid_[1]);
|
|
proto.set_grid_2(grid_[2]);
|
|
for (const Parameter& param : parameters_) {
|
|
*proto.add_parameters() = param.ToProto();
|
|
}
|
|
return proto;
|
|
}
|
|
|
|
private:
|
|
Kernel kernel_;
|
|
uint32_t grid_[3];
|
|
std::vector<Parameter> parameters_;
|
|
};
|
|
|
|
class AutotunedKernelCall : public KernelCallBase {
|
|
public:
|
|
struct Config {
|
|
KernelCall kernel_call;
|
|
std::string description;
|
|
};
|
|
|
|
AutotunedKernelCall(
|
|
std::string name, std::vector<Config> configs,
|
|
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases)
|
|
: name_(std::move(name)),
|
|
configs_(std::move(configs)),
|
|
input_output_aliases_(std::move(input_output_aliases)) {}
|
|
|
|
absl::Status Launch(CUstream stream, void** buffers) override {
|
|
absl::call_once(autotune_once_, [=]() {
|
|
if (configs_.size() > 1) {
|
|
autotune_status_ = Autotune(stream, buffers);
|
|
}
|
|
});
|
|
JAX_RETURN_IF_ERROR(autotune_status_);
|
|
return configs_[0].kernel_call.Launch(stream, buffers);
|
|
}
|
|
|
|
static absl::StatusOr<std::unique_ptr<AutotunedKernelCall>> FromProto(
|
|
const TritonAutotunedKernelCall& proto) {
|
|
std::vector<Config> configs;
|
|
for (const TritonAutotunedKernelCall_Config& config : proto.configs()) {
|
|
JAX_ASSIGN_OR_RETURN(auto kernel_call,
|
|
KernelCall::FromProto(config.kernel_call()));
|
|
configs.push_back(Config{std::move(kernel_call), config.description()});
|
|
}
|
|
|
|
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases;
|
|
for (const TritonAutotunedKernelCall_InputOutputAlias& a :
|
|
proto.input_output_aliases()) {
|
|
input_output_aliases.push_back(std::make_tuple(
|
|
a.input_buffer_idx(), a.output_buffer_idx(), a.buffer_size_bytes()));
|
|
}
|
|
|
|
return std::make_unique<AutotunedKernelCall>(
|
|
proto.name(), std::move(configs), std::move(input_output_aliases));
|
|
}
|
|
|
|
TritonAutotunedKernelCall ToProto() const {
|
|
TritonAutotunedKernelCall proto;
|
|
proto.set_name(name_);
|
|
for (const Config& config : configs_) {
|
|
TritonAutotunedKernelCall_Config* c = proto.add_configs();
|
|
*c->mutable_kernel_call() = config.kernel_call.ToProto();
|
|
c->set_description(config.description);
|
|
}
|
|
for (const auto& [input_idx, output_idx, size] : input_output_aliases_) {
|
|
TritonAutotunedKernelCall_InputOutputAlias* a =
|
|
proto.add_input_output_aliases();
|
|
a->set_input_buffer_idx(input_idx);
|
|
a->set_output_buffer_idx(output_idx);
|
|
a->set_buffer_size_bytes(size);
|
|
}
|
|
return proto;
|
|
}
|
|
|
|
private:
|
|
static constexpr float kBenchmarkTimeMillis = 10.;
|
|
|
|
absl::Status Autotune(CUstream stream, void** buffers) {
|
|
// Ensure a valid context for driver calls that don't take the stream.
|
|
CUcontext context;
|
|
CUDA_RETURN_IF_ERROR(cuStreamGetCtx(stream, &context));
|
|
CUDA_RETURN_IF_ERROR(cuCtxPushCurrent(context));
|
|
absl::Cleanup ctx_restorer = [] { cuCtxPopCurrent(nullptr); };
|
|
|
|
// If an input aliases with an output, it will get overwritten during the
|
|
// kernel execution. If the kernel is called repeatedly, as we do during
|
|
// auto-tuning, the final result will be junk, so we take a copy of the
|
|
// input to restore after auto-tuning.
|
|
std::unordered_map<size_t, std::vector<uint8_t>> input_copies;
|
|
for (auto [input_idx, output_idx, size] : input_output_aliases_) {
|
|
if (buffers[input_idx] == buffers[output_idx]) {
|
|
std::vector<uint8_t> input_copy(size);
|
|
CUDA_RETURN_IF_ERROR(cuMemcpyDtoHAsync(
|
|
input_copy.data(),
|
|
reinterpret_cast<CUdeviceptr>(buffers[input_idx]), size, stream));
|
|
input_copies[input_idx] = std::move(input_copy);
|
|
}
|
|
}
|
|
|
|
LOG(INFO) << "Autotuning function: " << name_;
|
|
// First run a single iteration of each to config to determine how many
|
|
// iterations to run for benchmarking.
|
|
float best = std::numeric_limits<float>::infinity();
|
|
for (Config& config : configs_) {
|
|
JAX_ASSIGN_OR_RETURN(float t,
|
|
Benchmark(stream, config.kernel_call, buffers, 1));
|
|
LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
|
|
best = std::min(best, t);
|
|
}
|
|
|
|
int timed_iters =
|
|
std::max(static_cast<int>(kBenchmarkTimeMillis / best), 1);
|
|
if (timed_iters > 100) {
|
|
timed_iters = 100;
|
|
LOG(INFO) << "Benchmarking with 100 iters (capped at 100)";
|
|
} else {
|
|
timed_iters = std::min(timed_iters, 100);
|
|
LOG(INFO) << "Benchmarking with " << timed_iters
|
|
<< " iters (target time: " << kBenchmarkTimeMillis << " ms)";
|
|
}
|
|
|
|
best = std::numeric_limits<float>::infinity();
|
|
for (Config& config : configs_) {
|
|
JAX_ASSIGN_OR_RETURN(
|
|
float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
|
|
LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
|
|
<< t << " ms";
|
|
|
|
if (t < best) {
|
|
LOG(INFO) << config.description << " is the new best config";
|
|
best = t;
|
|
std::swap(config, configs_[0]);
|
|
}
|
|
}
|
|
|
|
// Discard all but the best config.
|
|
configs_.erase(configs_.begin() + 1, configs_.end());
|
|
|
|
LOG(INFO) << "Finished autotuning function: " << name_ << " best config "
|
|
<< configs_[0].description;
|
|
|
|
// Restore aliased inputs to their original values.
|
|
for (auto [input_idx, _, size] : input_output_aliases_) {
|
|
CUDA_RETURN_IF_ERROR(
|
|
cuMemcpyHtoDAsync(reinterpret_cast<CUdeviceptr>(buffers[input_idx]),
|
|
input_copies[input_idx].data(), size, stream));
|
|
}
|
|
// Synchronize stream to ensure copies are complete before the host copy
|
|
// is deleted.
|
|
return JAX_AS_STATUS(cuStreamSynchronize(stream));
|
|
}
|
|
|
|
absl::StatusOr<float> Benchmark(CUstream stream, KernelCall& kernel_call,
|
|
void** buffers, int num_iterations) {
|
|
CUevent start, stop;
|
|
CUDA_RETURN_IF_ERROR(cuEventCreate(&start, /*Flags=*/CU_EVENT_DEFAULT));
|
|
CUDA_RETURN_IF_ERROR(cuEventCreate(&stop, /*Flags=*/CU_EVENT_DEFAULT));
|
|
JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers)); // Warm-up.
|
|
CUDA_RETURN_IF_ERROR(cuEventRecord(start, stream));
|
|
for (int i = 0; i < num_iterations; ++i) {
|
|
JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers));
|
|
}
|
|
CUDA_RETURN_IF_ERROR(cuEventRecord(stop, stream));
|
|
CUDA_RETURN_IF_ERROR(cuEventSynchronize(stop));
|
|
float elapsed_ms;
|
|
CUDA_RETURN_IF_ERROR(cuEventElapsedTime(&elapsed_ms, start, stop));
|
|
CUDA_RETURN_IF_ERROR(cuEventDestroy(start));
|
|
CUDA_RETURN_IF_ERROR(cuEventDestroy(stop));
|
|
return elapsed_ms;
|
|
}
|
|
|
|
std::string name_;
|
|
// After auto-tuning, all configurations, except the best, will be discarded.
|
|
std::vector<Config> configs_;
|
|
// (input buffer idx, output buffer idx, size)
|
|
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases_;
|
|
absl::once_flag autotune_once_;
|
|
absl::Status autotune_status_;
|
|
};
|
|
|
|
// Returns the kernel call described by `opaque`, a zlib-compressed serialized
// `TritonAnyKernelCall`. Results are cached by the exact opaque bytes in a map
// that is never destroyed, so each distinct opaque string is decoded once.
// Calls are serialised by an internal mutex.
absl::StatusOr<KernelCallBase*> GetKernelCall(absl::string_view opaque) {
  static absl::Mutex mutex;
  static auto& kernel_calls =
      *new absl::flat_hash_map<std::string, std::unique_ptr<KernelCallBase>>
          ABSL_GUARDED_BY(mutex);

  absl::MutexLock lock(&mutex);
  if (auto cached = kernel_calls.find(opaque);
      ABSL_PREDICT_TRUE(cached != kernel_calls.end())) {
    return cached->second.get();
  }

  // The opaque data is a zlib compressed protobuf. The uncompressed size is
  // not known up front, so start from a guess and double it until the output
  // fits.
  std::string serialized;
  for (uLongf dest_len = 5 * opaque.size();; dest_len *= 2) {
    serialized.resize(dest_len);
    const int ret =
        uncompress(reinterpret_cast<Bytef*>(serialized.data()), &dest_len,
                   reinterpret_cast<const Bytef*>(opaque.data()),
                   opaque.size());
    if (ret == Z_OK) {
      // `uncompress` overwrites `dest_len` with the uncompressed size.
      serialized.resize(dest_len);
      break;
    }
    if (ret != Z_BUF_ERROR) {
      return absl::InvalidArgumentError("Failed to uncompress opaque data.");
    }
    // Z_BUF_ERROR: the string buffer wasn't large enough; retry bigger.
  }

  TritonAnyKernelCall proto;
  if (!proto.ParseFromString(serialized)) {
    return absl::InvalidArgumentError("Failed to parse serialized data.");
  }

  std::unique_ptr<KernelCallBase> kernel_call;
  if (proto.has_kernel_call()) {
    absl::StatusOr<KernelCall> plain =
        KernelCall::FromProto(proto.kernel_call());
    if (ABSL_PREDICT_FALSE(!plain.ok())) {
      return plain.status();
    }
    kernel_call = std::make_unique<KernelCall>(*std::move(plain));
  } else if (proto.has_autotuned_kernel_call()) {
    absl::StatusOr<std::unique_ptr<AutotunedKernelCall>> autotuned =
        AutotunedKernelCall::FromProto(proto.autotuned_kernel_call());
    if (ABSL_PREDICT_FALSE(!autotuned.ok())) {
      return autotuned.status();
    }
    kernel_call = *std::move(autotuned);
  } else {
    return absl::InvalidArgumentError("Unknown kernel call type.");
  }

  auto [inserted, success] =
      kernel_calls.insert({std::string(opaque), std::move(kernel_call)});
  CHECK(success);
  return inserted->second.get();
}
|
|
|
|
} // namespace
|
|
|
|
// Custom-call entry point: decodes the kernel call described by the
// `opaque_len` bytes at `opaque`, launches it on `stream` with `buffers`, and
// reports any failure through `status`.
void LaunchTritonKernel(CUstream stream, void** buffers, const char* opaque,
                        size_t opaque_len, XlaCustomCallStatus* status) {
  absl::StatusOr<KernelCallBase*> kernel_call =
      GetKernelCall(absl::string_view(opaque, opaque_len));

  absl::Status result;
  if (kernel_call.ok()) {
    result = (*kernel_call)->Launch(stream, buffers);
  } else {
    result = kernel_call.status();
  }

  if (result.ok()) {
    return;
  }
  absl::string_view msg = result.message();
  XlaCustomCallStatusSetFailure(status, msg.data(), msg.length());
}
|
|
|
|
// Python bindings for the `_triton` extension module: constructors for
// kernels, parameters and kernel calls, serialisation of calls to protobuf
// bytes, the custom-call entry point, and a compute-capability query.
PYBIND11_MODULE(_triton, m) {
  py::class_<Kernel>(m, "TritonKernel")
      .def(py::init<std::string, uint32_t, uint32_t, std::string, std::string,
                    int>());

  py::class_<KernelCall::Parameter>(m, "TritonParameter");

  m.def("create_array_parameter",
        [](size_t bytes_to_zero, size_t ptr_divisibility) {
          return KernelCall::Parameter{
              KernelCall::Parameter::Array{bytes_to_zero, ptr_divisibility}};
        });

  // Boolean scalars; `dtype` must be "int1" or "B".
  m.def("create_scalar_parameter",
        [](py::bool_ value,
           std::string_view dtype) -> absl::StatusOr<KernelCall::Parameter> {
          if ((dtype == "int1") || (dtype == "B")) {
            return KernelCall::Parameter{static_cast<bool>(value)};
          } else {
            // Format the view itself: `dtype.data()` is not guaranteed to be
            // NUL-terminated.
            return absl::InvalidArgumentError(
                absl::StrFormat("unknown dtype: %s", dtype));
          }
        });

  // Integer scalars; `dtype` must be one of "i32", "u32", "i64" or "u64".
  m.def("create_scalar_parameter",
        [](py::int_ value,
           std::string_view dtype) -> absl::StatusOr<KernelCall::Parameter> {
          if (dtype == "i32") {
            return KernelCall::Parameter{static_cast<int32_t>(value)};
          } else if (dtype == "u32") {
            return KernelCall::Parameter{static_cast<uint32_t>(value)};
          } else if (dtype == "i64") {
            return KernelCall::Parameter{static_cast<int64_t>(value)};
          } else if (dtype == "u64") {
            return KernelCall::Parameter{static_cast<uint64_t>(value)};
          } else {
            return absl::InvalidArgumentError(
                absl::StrFormat("unknown dtype: %s", dtype));
          }
        });

  py::class_<KernelCall>(m, "TritonKernelCall")
      .def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
                    std::vector<KernelCall::Parameter>>())
      // Serialises the call, with `metadata` attached, to protobuf bytes.
      .def("to_proto",
           [](const KernelCall& kernel_call,
              py::bytes metadata) -> absl::StatusOr<py::bytes> {
             TritonAnyKernelCall proto;
             *proto.mutable_kernel_call() = kernel_call.ToProto();
             py::buffer_info metadata_info(py::buffer(metadata).request());
             proto.set_metadata(metadata_info.ptr, metadata_info.size);
             std::string serialized = proto.SerializeAsString();
             return py::bytes(serialized.data(), serialized.size());
           });

  py::class_<AutotunedKernelCall>(m, "TritonAutotunedKernelCall")
      .def(py::init<>([](std::string name,
                         std::vector<std::pair<KernelCall, std::string>>
                             calls_and_descriptions,
                         std::vector<std::tuple<size_t, size_t, size_t>>
                             input_output_aliases) {
        std::vector<AutotunedKernelCall::Config> configs;
        configs.reserve(calls_and_descriptions.size());
        for (auto& [kernel_call, desc] : calls_and_descriptions) {
          configs.push_back({std::move(kernel_call), std::move(desc)});
        }
        return std::make_unique<AutotunedKernelCall>(
            std::move(name), std::move(configs),
            std::move(input_output_aliases));
      }))
      // Serialises the autotuned call, with `metadata` attached, to protobuf
      // bytes.
      .def("to_proto",
           [](const AutotunedKernelCall& kernel_call,
              py::bytes metadata) -> absl::StatusOr<py::bytes> {
             TritonAnyKernelCall proto;
             *proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
             py::buffer_info metadata_info(py::buffer(metadata).request());
             proto.set_metadata(metadata_info.ptr, metadata_info.size);
             std::string serialized = proto.SerializeAsString();
             return py::bytes(serialized.data(), serialized.size());
           });

  m.def("get_custom_call",
        [] { return jax::EncapsulateFunction(&LaunchTritonKernel); });

  // Returns the compute capability of device ordinal `device`, encoded as
  // `major * 10 + minor`.
  m.def("get_compute_capability", [](int device) -> absl::StatusOr<int> {
    // `cuInit` takes a flags argument that must be 0, not a device ordinal.
    CUDA_RETURN_IF_ERROR(cuInit(/*Flags=*/0));
    // Translate the ordinal into a device handle; this also validates it.
    CUdevice cu_device;
    CUDA_RETURN_IF_ERROR(cuDeviceGet(&cu_device, device));
    int major, minor;
    CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
        &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cu_device));
    CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
        &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cu_device));
    return major * 10 + minor;
  });
}
|
|
|
|
} // namespace jax_triton
|