2023-05-24 12:09:57 -07:00
|
|
|
#include <cstdint>
|
|
|
|
#include <memory>
|
|
|
|
#include <string>
|
|
|
|
#include <string_view>
|
2023-06-22 04:56:43 -07:00
|
|
|
#include <tuple>
|
2023-05-24 12:09:57 -07:00
|
|
|
#include <vector>
|
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
#include "nanobind/nanobind.h"
|
|
|
|
#include "nanobind/stl/pair.h"
|
|
|
|
#include "nanobind/stl/string.h"
|
|
|
|
#include "nanobind/stl/string_view.h"
|
|
|
|
#include "nanobind/stl/tuple.h"
|
|
|
|
#include "nanobind/stl/vector.h"
|
2023-05-24 12:09:57 -07:00
|
|
|
#include "absl/status/statusor.h"
|
2023-08-24 16:06:18 -07:00
|
|
|
#include "jaxlib/absl_status_casters.h"
|
2023-05-25 10:25:11 -07:00
|
|
|
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
2023-06-22 04:56:43 -07:00
|
|
|
#include "jaxlib/gpu/triton.pb.h"
|
2023-07-03 06:51:45 -07:00
|
|
|
#include "jaxlib/gpu/triton_kernels.h"
|
2023-08-16 02:53:05 -07:00
|
|
|
#include "jaxlib/gpu/triton_utils.h"
|
2023-05-24 12:09:57 -07:00
|
|
|
#include "jaxlib/gpu/vendor.h"
|
2023-08-24 16:06:18 -07:00
|
|
|
#include "jaxlib/kernel_nanobind_helpers.h"
|
2023-05-24 12:09:57 -07:00
|
|
|
|
2023-09-20 18:43:16 +00:00
|
|
|
#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
|
2023-05-24 12:09:57 -07:00
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
namespace nb = nanobind;
|
2023-05-24 12:09:57 -07:00
|
|
|
|
2023-05-25 10:25:11 -07:00
|
|
|
namespace jax::JAX_GPU_NAMESPACE {
|
2023-05-24 12:09:57 -07:00
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
NB_MODULE(_triton, m) {
|
|
|
|
nb::class_<Kernel>(m, "TritonKernel")
|
|
|
|
.def(nb::init<std::string, uint32_t, uint32_t, std::string, std::string,
|
2024-01-19 05:55:03 -08:00
|
|
|
int, uint32_t, uint32_t, uint32_t>());
|
2023-06-22 04:56:43 -07:00
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
nb::class_<KernelCall::Parameter>(m, "TritonParameter");
|
2023-06-22 04:56:43 -07:00
|
|
|
|
|
|
|
m.def("create_array_parameter",
|
|
|
|
[](size_t bytes_to_zero, size_t ptr_divisibility) {
|
|
|
|
return KernelCall::Parameter{
|
|
|
|
KernelCall::Parameter::Array{bytes_to_zero, ptr_divisibility}};
|
|
|
|
});
|
|
|
|
|
|
|
|
m.def("create_scalar_parameter",
|
2023-08-24 16:06:18 -07:00
|
|
|
ValueOrThrowWrapper([](bool value, std::string_view dtype)
|
|
|
|
-> absl::StatusOr<KernelCall::Parameter> {
|
2023-07-28 22:56:36 -07:00
|
|
|
if ((dtype == "i1") || (dtype == "B")) {
|
2023-08-24 16:06:18 -07:00
|
|
|
return KernelCall::Parameter{value};
|
2023-06-22 04:56:43 -07:00
|
|
|
} else {
|
|
|
|
return absl::InvalidArgumentError(std::string("unknown dtype: ") +
|
|
|
|
dtype.data());
|
|
|
|
}
|
2023-08-24 16:06:18 -07:00
|
|
|
}));
|
2023-06-22 04:56:43 -07:00
|
|
|
|
|
|
|
m.def("create_scalar_parameter",
|
2023-08-24 16:06:18 -07:00
|
|
|
ValueOrThrowWrapper([](nb::int_ value, std::string_view dtype)
|
|
|
|
-> absl::StatusOr<KernelCall::Parameter> {
|
2023-06-22 04:56:43 -07:00
|
|
|
if (dtype == "i32") {
|
|
|
|
return KernelCall::Parameter{static_cast<int32_t>(value)};
|
|
|
|
} else if (dtype == "u32") {
|
|
|
|
return KernelCall::Parameter{static_cast<uint32_t>(value)};
|
|
|
|
} else if (dtype == "i64") {
|
|
|
|
return KernelCall::Parameter{static_cast<int64_t>(value)};
|
|
|
|
} else if (dtype == "u64") {
|
|
|
|
return KernelCall::Parameter{static_cast<uint64_t>(value)};
|
|
|
|
} else {
|
|
|
|
return absl::InvalidArgumentError(std::string("unknown dtype: ") +
|
|
|
|
dtype.data());
|
|
|
|
}
|
2023-08-24 16:06:18 -07:00
|
|
|
}));
|
2023-06-22 04:56:43 -07:00
|
|
|
|
2023-07-28 22:56:36 -07:00
|
|
|
m.def("create_scalar_parameter",
|
2023-08-24 16:06:18 -07:00
|
|
|
ValueOrThrowWrapper([](double value, std::string_view dtype)
|
|
|
|
-> absl::StatusOr<KernelCall::Parameter> {
|
2023-07-28 22:56:36 -07:00
|
|
|
if (dtype == "fp32") {
|
|
|
|
return KernelCall::Parameter{static_cast<float>(value)};
|
|
|
|
} else if (dtype == "fp64") {
|
|
|
|
return KernelCall::Parameter{static_cast<double>(value)};
|
|
|
|
} else {
|
|
|
|
return absl::InvalidArgumentError(std::string("unknown dtype: ") +
|
|
|
|
dtype.data());
|
|
|
|
}
|
2023-08-24 16:06:18 -07:00
|
|
|
}));
|
2023-07-28 22:56:36 -07:00
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
nb::class_<KernelCall>(m, "TritonKernelCall")
|
|
|
|
.def(nb::init<Kernel, uint32_t, uint32_t, uint32_t,
|
2023-06-22 04:56:43 -07:00
|
|
|
std::vector<KernelCall::Parameter>>())
|
2023-08-16 02:53:05 -07:00
|
|
|
.def("to_proto", [](const KernelCall& kernel_call, std::string name,
|
2023-08-24 16:06:18 -07:00
|
|
|
nb::bytes metadata) {
|
2023-07-03 06:51:45 -07:00
|
|
|
jax_triton::TritonAnyKernelCall proto;
|
2023-06-30 03:22:48 -07:00
|
|
|
*proto.mutable_kernel_call() = kernel_call.ToProto();
|
2023-08-16 02:53:05 -07:00
|
|
|
proto.set_name(std::move(name));
|
2023-08-24 16:06:18 -07:00
|
|
|
proto.set_metadata(metadata.c_str(), metadata.size());
|
|
|
|
std::string s = proto.SerializeAsString();
|
|
|
|
return nb::bytes(s.c_str(), s.size());
|
2023-06-30 03:22:48 -07:00
|
|
|
});
|
2023-06-22 04:56:43 -07:00
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
nb::class_<AutotunedKernelCall>(m, "TritonAutotunedKernelCall")
|
|
|
|
.def("__init__",
|
|
|
|
[](AutotunedKernelCall* call, std::string name,
|
|
|
|
std::vector<std::pair<KernelCall, std::string>>
|
|
|
|
calls_and_descriptions,
|
|
|
|
std::vector<std::tuple<size_t, size_t, size_t>>
|
|
|
|
input_output_aliases) {
|
|
|
|
std::vector<AutotunedKernelCall::Config> configs;
|
|
|
|
configs.reserve(calls_and_descriptions.size());
|
|
|
|
for (auto& [kernel_call, desc] : calls_and_descriptions) {
|
|
|
|
configs.push_back({std::move(kernel_call), std::move(desc)});
|
|
|
|
}
|
|
|
|
new (call) AutotunedKernelCall(std::move(name), std::move(configs),
|
|
|
|
std::move(input_output_aliases));
|
|
|
|
})
|
2023-08-16 02:53:05 -07:00
|
|
|
.def("to_proto", [](const AutotunedKernelCall& kernel_call,
|
2023-08-24 16:06:18 -07:00
|
|
|
std::string name, nb::bytes metadata) {
|
2023-08-16 02:53:05 -07:00
|
|
|
jax_triton::TritonAnyKernelCall proto;
|
|
|
|
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
|
|
|
|
proto.set_name(std::move(name));
|
2023-08-24 16:06:18 -07:00
|
|
|
proto.set_metadata(metadata.c_str(), metadata.size());
|
|
|
|
std::string s = proto.SerializeAsString();
|
|
|
|
return nb::bytes(s.c_str(), s.size());
|
2023-08-16 02:53:05 -07:00
|
|
|
});
|
2023-05-24 12:09:57 -07:00
|
|
|
|
2023-06-21 10:37:15 -07:00
|
|
|
m.def("get_custom_call",
|
2023-07-03 06:51:45 -07:00
|
|
|
[] { return EncapsulateFunction(&TritonKernelCall); });
|
2023-05-24 12:09:57 -07:00
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
m.def("get_compute_capability",
|
|
|
|
ValueOrThrowWrapper([](int device) -> absl::StatusOr<int> {
|
|
|
|
int major, minor;
|
2023-09-20 18:43:16 +00:00
|
|
|
GPU_RETURN_IF_ERROR(gpuInit(device));
|
|
|
|
GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
|
|
|
|
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
|
|
|
|
GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
|
|
|
|
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
|
2023-08-24 16:06:18 -07:00
|
|
|
return major * 10 + minor;
|
|
|
|
}));
|
2023-06-30 03:22:48 -07:00
|
|
|
|
2024-07-25 19:02:33 +00:00
|
|
|
m.def(
|
|
|
|
"get_arch_details",
|
|
|
|
ValueOrThrowWrapper([](int device) -> absl::StatusOr<absl::string_view> {
|
|
|
|
#ifdef JAX_GPU_HIP
|
|
|
|
hipDeviceProp_t prop;
|
|
|
|
hipGetDeviceProperties(&prop, 0);
|
|
|
|
return prop.gcnArchName;
|
|
|
|
#else
|
|
|
|
return absl::UnimplementedError("Not a HIP GPU");
|
|
|
|
#endif
|
|
|
|
}));
|
|
|
|
|
2023-08-16 02:53:05 -07:00
|
|
|
m.def("get_serialized_metadata",
|
2023-08-24 16:06:18 -07:00
|
|
|
ValueOrThrowWrapper(
|
2023-08-25 07:30:44 -07:00
|
|
|
[](nb::bytes opaque) -> absl::StatusOr<nb::bytes> {
|
2023-08-24 16:06:18 -07:00
|
|
|
JAX_ASSIGN_OR_RETURN(
|
|
|
|
std::string metadata,
|
2023-08-25 07:30:44 -07:00
|
|
|
GetTritonKernelCallSerializedMetadata(
|
|
|
|
absl::string_view(opaque.c_str(), opaque.size())));
|
2023-08-24 16:06:18 -07:00
|
|
|
return nb::bytes(metadata.c_str(), metadata.size());
|
|
|
|
}));
|
2023-05-24 12:09:57 -07:00
|
|
|
}
|
|
|
|
|
2023-07-03 06:51:45 -07:00
|
|
|
} // namespace jax::JAX_GPU_NAMESPACE
|