#include "jaxlib/gpu/triton_kernels.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/triton_utils.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"

#ifdef JAX_GPU_CUDA
#include "xla/stream_executor/cuda/cuda_asm_compiler.h"
#endif  // JAX_GPU_CUDA

#ifdef JAX_GPU_HIP
#include "tsl/platform/env.h"
#endif  // JAX_GPU_HIP
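
// Evaluates a raw CUDA/HIP driver expression, converts the result to an
// absl::Status, and returns from the enclosing function on failure.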
#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))

namespace jax::JAX_GPU_NAMESPACE {
namespace {
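
// Target amount of GPU time, per config, for the timed autotuning loop in
// AutotunedKernelCall::Autotune below.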
constexpr float kBenchmarkTimeMillis = 10.;
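
// Deleter for use with std::unique_ptr; unloads the GPU module, logging
// (rather than propagating) any failure.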
struct gpuModuleDeleter {
  void operator()(gpuModule_t module) {
    absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module));
    if (!status.ok()) {
      LOG(WARNING) << "Failed to unload GPU module: " << status;
    }
  }
};

using OwnedGPUmodule =
    std::unique_ptr<std::remove_pointer_t<gpuModule_t>, gpuModuleDeleter>;
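
// Returns the device on which `stream` runs. On HIP the device ID is
// recorded on the stream itself; on CUDA we temporarily push the stream's
// context and query the current device.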
absl::StatusOr<gpuDevice_t> GetStreamDevice(gpuStream_t stream) {
  gpuDevice_t device;
#ifdef JAX_GPU_HIP
  int device_id = gpuGetStreamDeviceId(stream);
  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
#else  // JAX_GPU_CUDA
  gpuContext_t context;
  GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
  GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
  absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
  GPU_RETURN_IF_ERROR(gpuCtxGetDevice(&device));
#endif
  return device;
}
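
// Returns the maximum opt-in shared memory per block, in bytes, for `device`.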
absl::StatusOr<uint32_t> MaxSharedMemoryPerBlock(gpuDevice_t device) {
  int shared_optin;
  GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
      &shared_optin, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  return shared_optin;
}
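
// Returns a ModuleImage for the requested kernel from a process-lifetime
// cache, keyed on (kernel name, shared memory size, PTX, compute capability).
// On a cache miss the PTX is compiled to a CUBIN (CUDA), or the HSACO file
// named by `ptx` is read from disk (HIP).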
absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
                                            uint32_t shared_mem_bytes,
                                            std::string_view ptx,
                                            int compute_capability) {
  auto key =
      std::make_tuple(kernel_name, shared_mem_bytes, ptx, compute_capability);

  static absl::Mutex mutex;
  static auto& module_images =
      *new absl::flat_hash_map<decltype(key), std::unique_ptr<ModuleImage>>
          ABSL_GUARDED_BY(mutex);

  absl::MutexLock lock(&mutex);
  auto it = module_images.find(key);
  if (it != module_images.end()) return it->second.get();

#ifdef JAX_GPU_HIP  // For HIP/ROCM just read the hsaco file
  std::string result_blob;
  std::string fname{ptx};
  TF_RETURN_IF_ERROR(
      tsl::ReadFileToString(tsl::Env::Default(), fname, &result_blob));
  std::vector<uint8_t> module_image(result_blob.begin(), result_blob.end());
#else
  // TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable?
  int cc_major = compute_capability / 10;
  int cc_minor = compute_capability % 10;
  JAX_ASSIGN_OR_RETURN(
      std::vector<uint8_t> module_image,
      stream_executor::CompileGpuAsm(cc_major, cc_minor, ptx.data(),
                                     stream_executor::GpuAsmOpts{}));
#endif

  auto [it2, success] = module_images.insert(
      {std::move(key),
       std::make_unique<ModuleImage>(
           std::move(kernel_name), std::move(module_image), shared_mem_bytes)});
  CHECK(success);
  return it2->second.get();
}
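
// Runs `kernel_call` once to warm up, then times `num_iterations` launches
// using GPU events and returns the elapsed time in milliseconds.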
absl::StatusOr<float> Benchmark(gpuStream_t stream, KernelCall& kernel_call,
                                void** buffers, int num_iterations) {
  gpuEvent_t start, stop;
  GPU_RETURN_IF_ERROR(gpuEventCreate(&start, /*Flags=*/GPU_EVENT_DEFAULT));
  GPU_RETURN_IF_ERROR(gpuEventCreate(&stop, /*Flags=*/GPU_EVENT_DEFAULT));
  JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers));  // Warm-up.
  GPU_RETURN_IF_ERROR(gpuEventRecord(start, stream));
  for (int i = 0; i < num_iterations; ++i) {
    JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers));
  }
  GPU_RETURN_IF_ERROR(gpuEventRecord(stop, stream));
  GPU_RETURN_IF_ERROR(gpuEventSynchronize(stop));
  float elapsed_ms;
  GPU_RETURN_IF_ERROR(gpuEventElapsedTime(&elapsed_ms, start, stop));
  GPU_RETURN_IF_ERROR(gpuEventDestroy(start));
  GPU_RETURN_IF_ERROR(gpuEventDestroy(stop));
  return elapsed_ms;
}
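
// Returns the KernelCall corresponding to `opaque`, deserializing (and, if
// necessary, autotuning) it on first use and caching the result for the
// lifetime of the process. A reader lock serves the common cache-hit path;
// an exclusive lock is taken only to create the entry.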
absl::StatusOr<KernelCall*> GetKernelCall(absl::string_view opaque,
                                          gpuStream_t stream, void** buffers) {
  static absl::Mutex mutex;
  static auto& kernel_calls =
      *new absl::flat_hash_map<std::string,
                               absl::StatusOr<std::unique_ptr<KernelCall>>>
          ABSL_GUARDED_BY(mutex);

  {
    // Fast path uses reader lock (as hash map look-up is relatively slow).
    absl::ReaderMutexLock lock(&mutex);
    auto it = kernel_calls.find(opaque);
    if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) {
      JAX_RETURN_IF_ERROR(it->second.status());
      return it->second->get();
    }
  }

  if (opaque.empty()) {
    return absl::InvalidArgumentError("Opaque data is empty.");
  }

  absl::MutexLock lock(&mutex);

  auto get_kernel_call = [&]() -> absl::StatusOr<std::unique_ptr<KernelCall>> {
    // The opaque data is a zlib compressed protobuf.
    JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));

    jax_triton::TritonAnyKernelCall proto;
    if (!proto.ParseFromString(serialized)) {
      return absl::InvalidArgumentError("Failed to parse serialized data.");
    }

    if (proto.has_kernel_call()) {
      JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_,
                           KernelCall::FromProto(proto.kernel_call()));
      return std::make_unique<KernelCall>(std::move(kernel_call_));
    } else if (proto.has_autotuned_kernel_call()) {
      JAX_ASSIGN_OR_RETURN(
          AutotunedKernelCall autotuned_call,
          AutotunedKernelCall::FromProto(proto.autotuned_kernel_call()));
      {
        JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_,
                             AutotunedKernelCall::Autotune(
                                 std::move(autotuned_call), stream, buffers));
        return std::make_unique<KernelCall>(std::move(kernel_call_));
      }
    } else {
      return absl::InvalidArgumentError("Unknown kernel call type.");
    }
  };

  // We released the reader lock, so another thread may have created the entry
  // in the meantime; emplace keeps the existing entry in that case.
  auto it = kernel_calls.emplace(std::string(opaque), get_kernel_call()).first;

  JAX_RETURN_IF_ERROR(it->second.status());
  return it->second->get();
}

}  // namespace
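
// A compiled kernel image (CUBIN on CUDA, HSACO on HIP) together with the
// module and function handles loaded from it, cached per GPU context.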
class ModuleImage {
 public:
  ModuleImage(std::string_view kernel_name, std::vector<uint8_t> module_image,
              uint32_t shared_mem_bytes)
      : kernel_name_(kernel_name),
        module_image_(std::move(module_image)),
        shared_mem_bytes_(shared_mem_bytes) {}

  absl::StatusOr<gpuFunction_t> GetFunctionForContext(gpuContext_t context) {
    absl::MutexLock lock(&mutex_);
    auto it = functions_.find(context);
    if (ABSL_PREDICT_TRUE(it != functions_.end())) {
      return it->second;
    }

    GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
    absl::Cleanup ctx_restorer = [] {
      absl::Status status = JAX_AS_STATUS(gpuCtxPopCurrent(nullptr));
      if (!status.ok()) {
        LOG(WARNING) << "Failed to pop GPU context: " << status;
      }
    };

    gpuModule_t module;
    GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data()));
    modules_.push_back(OwnedGPUmodule(module, gpuModuleDeleter()));

    gpuFunction_t function;
    GPU_RETURN_IF_ERROR(
        gpuModuleGetFunction(&function, module, kernel_name_.c_str()));
    auto [_, success] = functions_.insert({context, function});
    CHECK(success);

    // The maximum permitted static shared memory allocation in CUDA is 48kB,
    // but we can expose more to the kernel using dynamic shared memory.
    constexpr int kMaxStaticSharedMemBytes = 49152;
    if (shared_mem_bytes_ <= kMaxStaticSharedMemBytes) {
      return function;
    }

    // Set up dynamic shared memory.
    gpuDevice_t device;
    GPU_RETURN_IF_ERROR(gpuCtxGetDevice(&device));

    int shared_optin;
    GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
        &shared_optin, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));

    if (shared_mem_bytes_ > shared_optin) {
      return absl::InvalidArgumentError(absl::StrFormat(
          "Shared memory requested (%d b) exceeds device resources (%d b).",
          shared_mem_bytes_, shared_optin));
    }

    if (shared_optin > kMaxStaticSharedMemBytes) {
#ifdef JAX_GPU_CUDA
      GPU_RETURN_IF_ERROR(
          gpuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED));
#endif
      int shared_total;
      GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
          &shared_total,
          GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
      int shared_static;
      GPU_RETURN_IF_ERROR(gpuFuncGetAttribute(
          &shared_static, GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
#ifdef JAX_GPU_CUDA
      GPU_RETURN_IF_ERROR(cuFuncSetAttribute(
          function, GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
          shared_optin - shared_static));
#endif
    }
    return function;
  }

 private:
  std::string kernel_name_;
  std::vector<uint8_t> module_image_;
  uint32_t shared_mem_bytes_;

  absl::Mutex mutex_;
  std::vector<OwnedGPUmodule> modules_ ABSL_GUARDED_BY(mutex_);
  absl::flat_hash_map<gpuContext_t, gpuFunction_t> functions_
      ABSL_GUARDED_BY(mutex_);
};

Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
               uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
               int compute_capability, uint32_t cluster_dim_0,
               uint32_t cluster_dim_1, uint32_t cluster_dim_2)
    : kernel_name_(std::move(kernel_name)),
      block_dim_x_(num_warps * kNumThreadsPerWarp),
      shared_mem_bytes_(shared_mem_bytes),
      ptx_(std::move(ptx)),
      ttir_(std::move(ttir)),
      compute_capability_(compute_capability),
      cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {}
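
// Launches the kernel on `stream`. The compiled module image is resolved
// lazily on first launch, and the function handle is looked up in the
// stream's GPU context.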
absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
                            void** params) {
  if (ABSL_PREDICT_FALSE(module_image_ == nullptr)) {
    JAX_ASSIGN_OR_RETURN(module_image_,
                         GetModuleImage(kernel_name_, shared_mem_bytes_, ptx_,
                                        compute_capability_));
  }

  gpuContext_t context;
#ifdef JAX_GPU_HIP
  int device_id = gpuGetStreamDeviceId(stream);
  gpuDevice_t device;
  GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
  GPU_RETURN_IF_ERROR(gpuDevicePrimaryCtxRetain(&context, device));
  JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
                       module_image_->GetFunctionForContext(context));
  return JAX_AS_STATUS(gpuLaunchKernel(
      kernel, grid[0], grid[1], grid[2], block_dim_x_,
      /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
      /*extra=*/nullptr));
#else  // JAX_GPU_CUDA
  // TODO(b/324319767): A bug in CUDA prevents us from calling cuStreamGetCtx
  // inside graph capture. We use cuCtxGetCurrent as a workaround here, since
  // the current context is not changed during capture; we should switch back
  // to cuStreamGetCtx once the bug is fixed.
  gpustreamCaptureStatus_t capture_status;
  GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
  if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
    GPU_RETURN_IF_ERROR(gpuCtxGetCurrent(&context));
  } else {
    GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
  }

  JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
                       module_image_->GetFunctionForContext(context));
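  // Without cluster dimensions, use the plain launch API. Otherwise launch
  // through cuLaunchKernelEx, which takes its grid dimensions in blocks, so
  // the cluster-level grid is multiplied out below.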
  const uint32_t cluster_size =
      cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2];
  if (cluster_size <= 1) {
    return JAX_AS_STATUS(gpuLaunchKernel(
        kernel, grid[0], grid[1], grid[2], block_dim_x_,
        /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
        /*extra=*/nullptr));
  }
  CUlaunchAttribute launch_attrs[2];
  launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  launch_attrs[0].value.clusterDim.x = cluster_dims_[0];
  launch_attrs[0].value.clusterDim.y = cluster_dims_[1];
  launch_attrs[0].value.clusterDim.z = cluster_dims_[2];
  launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
  launch_attrs[1].value.clusterSchedulingPolicyPreference =
      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
  CUlaunchConfig launch_config = {
      /*gridDimX=*/grid[0] * cluster_dims_[0],
      /*gridDimY=*/grid[1] * cluster_dims_[1],
      /*gridDimZ=*/grid[2] * cluster_dims_[2],
      /*blockDimX=*/block_dim_x_,
      /*blockDimY=*/1,
      /*blockDimZ=*/1,
      /*sharedMemBytes=*/shared_mem_bytes_,
      /*hStream=*/stream,
      /*attrs=*/launch_attrs,
      /*numAttrs=*/2,
  };
  return JAX_AS_STATUS(
      cuLaunchKernelEx(&launch_config, kernel, params, /*extra=*/nullptr));
#endif
}

/*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) {
  return Kernel(proto.kernel_name(), proto.num_warps(),
                proto.shared_mem_bytes(), proto.ptx(), proto.ttir(),
                proto.compute_capability(), proto.cluster_dim_0(),
                proto.cluster_dim_1(), proto.cluster_dim_2());
}

jax_triton::TritonKernel Kernel::ToProto() const {
  jax_triton::TritonKernel proto;
  proto.set_kernel_name(kernel_name_);
  proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp);
  proto.set_shared_mem_bytes(shared_mem_bytes_);
  proto.set_ptx(ptx_);
  proto.set_ttir(ttir_);
  proto.set_compute_capability(compute_capability_);
  proto.set_cluster_dim_0(cluster_dims_[0]);
  proto.set_cluster_dim_1(cluster_dims_[1]);
  proto.set_cluster_dim_2(cluster_dims_[2]);
  return proto;
}

/*static*/ absl::StatusOr<KernelCall::Parameter>
KernelCall::Parameter::FromProto(
    const jax_triton::TritonKernelCall_Parameter& proto) {
  using jax_triton::TritonKernelCall_Parameter;
  Parameter param;
  switch (proto.value_case()) {
    case TritonKernelCall_Parameter::kArray:
      param.value = Array{proto.array().bytes_to_zero(),
                          proto.array().ptr_divisibility()};
      break;
    case TritonKernelCall_Parameter::kBool:
      param.value = proto.bool_();
      break;
    case TritonKernelCall_Parameter::kI32:
      param.value = proto.i32();
      break;
    case TritonKernelCall_Parameter::kU32:
      param.value = proto.u32();
      break;
    case TritonKernelCall_Parameter::kI64:
      param.value = proto.i64();
      break;
    case TritonKernelCall_Parameter::kU64:
      param.value = proto.u64();
      break;
    case TritonKernelCall_Parameter::kF32:
      param.value = proto.f32();
      break;
    case TritonKernelCall_Parameter::kF64:
      param.value = proto.f64();
      break;
    default:
      return absl::InvalidArgumentError("Unknown scalar parameter type.");
  }
  return param;
}

bool Kernel::CanLaunchOnDevice(gpuDevice_t device) const {
  return shared_mem_bytes_ <= MaxSharedMemoryPerBlock(device).value_or(0);
}

jax_triton::TritonKernelCall_Parameter KernelCall::Parameter::ToProto() const {
  jax_triton::TritonKernelCall_Parameter proto;
  if (std::holds_alternative<Array>(value)) {
    proto.mutable_array()->set_bytes_to_zero(
        std::get<Array>(value).bytes_to_zero);
    proto.mutable_array()->set_ptr_divisibility(
        std::get<Array>(value).ptr_divisibility);
  } else if (std::holds_alternative<bool>(value)) {
    proto.set_bool_(std::get<bool>(value));
  } else if (std::holds_alternative<int32_t>(value)) {
    proto.set_i32(std::get<int32_t>(value));
  } else if (std::holds_alternative<uint32_t>(value)) {
    proto.set_u32(std::get<uint32_t>(value));
  } else if (std::holds_alternative<int64_t>(value)) {
    proto.set_i64(std::get<int64_t>(value));
  } else if (std::holds_alternative<uint64_t>(value)) {
    proto.set_u64(std::get<uint64_t>(value));
  } else if (std::holds_alternative<float>(value)) {
    proto.set_f32(std::get<float>(value));
  } else {
    CHECK(std::holds_alternative<double>(value));
    proto.set_f64(std::get<double>(value));
  }
  return proto;
}

KernelCall::KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1,
                       uint32_t grid_2, std::vector<Parameter> parameters)
    : kernel_(std::move(kernel)),
      grid_{grid_0, grid_1, grid_2},
      parameters_(std::move(parameters)) {}
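
// Assembles the kernel parameter buffer and launches the kernel. The launch
// APIs take the address of each argument: for arrays we push a pointer to the
// buffer pointer, for scalars a pointer to the value stored in the parameter
// variant; all of these outlive the launch call.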
absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
  std::vector<void*> params;
  // We need an additional parameter for the scratchpad buffer.
  params.reserve(parameters_.size() + 1);
  for (size_t i = 0; i < parameters_.size(); ++i) {
    const Parameter& param = parameters_[i];
    if (std::holds_alternative<Parameter::Array>(param.value)) {
      const auto& array = std::get<Parameter::Array>(param.value);
      void*& ptr = *(buffers++);
      auto cu_ptr = reinterpret_cast<gpuDevicePtr_t>(ptr);

      if (ABSL_PREDICT_FALSE((array.ptr_divisibility != 0) &&
                             ((size_t)cu_ptr % array.ptr_divisibility != 0))) {
        return absl::InvalidArgumentError(
            absl::StrFormat("Parameter %zu (%zu) is not divisible by %d.", i,
                            (size_t)ptr, array.ptr_divisibility));
      }

      if (array.bytes_to_zero > 0) {
        GPU_RETURN_IF_ERROR(
            gpuMemsetD8Async(cu_ptr, 0, array.bytes_to_zero, stream));
      }
      params.push_back(&ptr);
    } else {
      params.push_back(const_cast<void*>(std::visit(
          [](auto&& arg) { return reinterpret_cast<const void*>(&arg); },
          param.value)));
    }
  }
  // Triton's kernel ABI expects an additional scratchpad global memory.
  // For now it is only used for on-device creation of TMA descriptors, which
  // we do not use yet, so we are just replacing this argument with a null
  // pointer.
  // TODO: b/381242007 - Allocate a proper buffer if we want to use
  // device-side TMA APIs.
  void* scratch_ptr = nullptr;  // Alive until kernel_.Launch returns.
  params.push_back(&scratch_ptr);

  return kernel_.Launch(stream, grid_, params.data());
}

/*static*/ absl::StatusOr<KernelCall> KernelCall::FromProto(
    const jax_triton::TritonKernelCall& proto) {
  std::vector<KernelCall::Parameter> parameters;
  for (const jax_triton::TritonKernelCall_Parameter& parameter :
       proto.parameters()) {
    JAX_ASSIGN_OR_RETURN(Parameter p, Parameter::FromProto(parameter));
    parameters.push_back(p);
  }

  return KernelCall(Kernel::FromProto(proto.kernel()), proto.grid_0(),
                    proto.grid_1(), proto.grid_2(), std::move(parameters));
}

jax_triton::TritonKernelCall KernelCall::ToProto() const {
  jax_triton::TritonKernelCall proto;
  *proto.mutable_kernel() = kernel_.ToProto();
  proto.set_grid_0(grid_[0]);
  proto.set_grid_1(grid_[1]);
  proto.set_grid_2(grid_[2]);
  for (const Parameter& param : parameters_) {
    *proto.add_parameters() = param.ToProto();
  }
  return proto;
}

bool KernelCall::CanLaunchOnDevice(gpuDevice_t device) const {
  return kernel_.CanLaunchOnDevice(device);
}

AutotunedKernelCall::AutotunedKernelCall(
    std::string name, std::vector<Config> configs,
    std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases)
    : name_(std::move(name)),
      configs_(std::move(configs)),
      input_output_aliases_(std::move(input_output_aliases)) {}

/*static*/ absl::StatusOr<AutotunedKernelCall> AutotunedKernelCall::FromProto(
    const jax_triton::TritonAutotunedKernelCall& proto) {
  std::vector<Config> configs;
  for (const jax_triton::TritonAutotunedKernelCall_Config& config :
       proto.configs()) {
    JAX_ASSIGN_OR_RETURN(auto kernel_call,
                         KernelCall::FromProto(config.kernel_call()));
    configs.push_back(Config{std::move(kernel_call), config.description()});
  }

  std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases;
  for (const jax_triton::TritonAutotunedKernelCall_InputOutputAlias& a :
       proto.input_output_aliases()) {
    input_output_aliases.push_back(std::make_tuple(
        a.input_buffer_idx(), a.output_buffer_idx(), a.buffer_size_bytes()));
  }

  return AutotunedKernelCall(proto.name(), std::move(configs),
                             std::move(input_output_aliases));
}

jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
  jax_triton::TritonAutotunedKernelCall proto;
  proto.set_name(name_);
  for (const Config& config : configs_) {
    jax_triton::TritonAutotunedKernelCall_Config* c = proto.add_configs();
    *c->mutable_kernel_call() = config.kernel_call.ToProto();
    c->set_description(config.description);
  }
  for (const auto& [input_idx, output_idx, size] : input_output_aliases_) {
    jax_triton::TritonAutotunedKernelCall_InputOutputAlias* a =
        proto.add_input_output_aliases();
    a->set_input_buffer_idx(input_idx);
    a->set_output_buffer_idx(output_idx);
    a->set_buffer_size_bytes(size);
  }
  return proto;
}
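
// Selects the fastest launchable config by benchmarking every candidate on
// `stream`: first one iteration of each config to pick an iteration count
// that fits kBenchmarkTimeMillis, then a timed run per config. The winning
// config is swapped into configs_[0] and returned.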
/*static*/ absl::StatusOr<KernelCall> AutotunedKernelCall::Autotune(
    AutotunedKernelCall kernel_call, gpuStream_t stream, void** buffers) {
  // Ensure a valid context for driver calls that don't take the stream.
  // gpuContext_t context;
  // GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
  // GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
  // absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };

  // Autotuning is not supported if the stream is in graph capture mode.
  gpustreamCaptureStatus_t capture_status;
  GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
  if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
    return absl::FailedPreconditionError(
        "Can't autotune Triton kernel when the stream is in graph capture "
        "mode. Autotuning can rely on real data present in input buffers to "
        "use them in address computation, but in graph capture mode buffers "
        "can have arbitrary data.");
  }

  // If an input aliases with an output, it will get overwritten during the
  // kernel execution. If the kernel is called repeatedly, as we do during
  // auto-tuning, the final result will be junk, so we take a copy of the
  // input to restore after auto-tuning.
  std::unordered_map<size_t, std::vector<uint8_t>> input_copies;
  for (auto [input_idx, output_idx, size] : kernel_call.input_output_aliases_) {
    if (buffers[input_idx] == buffers[output_idx]) {
      std::vector<uint8_t> input_copy(size);
      GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
          input_copy.data(),
          reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size, stream));
      input_copies[input_idx] = std::move(input_copy);
    }
  }

  LOG(INFO) << "Autotuning function: " << kernel_call.name_;
  // First run a single iteration of each config to determine how many
  // iterations to run for benchmarking.
  float best = std::numeric_limits<float>::infinity();
  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
  absl::flat_hash_set<Config*> configs_to_skip;
  for (Config& config : kernel_call.configs_) {
    if (!config.kernel_call.CanLaunchOnDevice(device)) {
      configs_to_skip.insert(&config);
      continue;
    }
    JAX_ASSIGN_OR_RETURN(float t,
                         Benchmark(stream, config.kernel_call, buffers, 1));
    LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
    best = std::min(best, t);
  }

  int timed_iters = std::max(static_cast<int>(kBenchmarkTimeMillis / best), 1);
  if (timed_iters > 100) {
    timed_iters = 100;
    LOG(INFO) << "Benchmarking with 100 iters (capped at 100)";
  } else {
    LOG(INFO) << "Benchmarking with " << timed_iters
              << " iters (target time: " << kBenchmarkTimeMillis << " ms)";
  }

  best = std::numeric_limits<float>::infinity();
  for (Config& config : kernel_call.configs_) {
    if (configs_to_skip.contains(&config)) {
      LOG(WARNING) << "Unable to launch autotune config on device: "
                   << config.description;
      continue;
    }

    JAX_ASSIGN_OR_RETURN(
        float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
    LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
              << t << " ms";

    if (t < best) {
      LOG(INFO) << config.description << " is the new best config";
      best = t;
      std::swap(config, kernel_call.configs_[0]);
    }
  }
  if (std::isinf(best)) {
    LOG(WARNING) << "Finished autotuning function: " << kernel_call.name_
                 << "; no valid configs found.";
    return absl::FailedPreconditionError("No launchable configs.");
  }

  LOG(INFO) << "Finished autotuning function: " << kernel_call.name_
            << "; best config: " << kernel_call.configs_[0].description;

  // Restore aliased inputs to their original values.
  for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
    GPU_RETURN_IF_ERROR(
        gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
                           input_copies[input_idx].data(), size, stream));
  }

  // Synchronize the stream to ensure the copies are complete before the host
  // copies are deleted.
  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));

  return std::move(kernel_call.configs_[0].kernel_call);
}
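
// XLA custom-call entry point. `opaque` holds the zlib-compressed
// TritonAnyKernelCall proto; failures are reported via XlaCustomCallStatus.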
void TritonKernelCall(gpuStream_t stream, void** buffers, const char* opaque,
                      size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = [=] {
    JAX_ASSIGN_OR_RETURN(
        KernelCall * kernel_call,
        GetKernelCall(absl::string_view(opaque, opaque_len), stream, buffers));
    return kernel_call->Launch(stream, buffers);
  }();
  if (!result.ok()) {
    absl::string_view msg = result.message();
    XlaCustomCallStatusSetFailure(status, msg.data(), msg.length());
  }
}

}  // namespace jax::JAX_GPU_NAMESPACE