diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 4ad764472..323f67c57 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -12,7 +12,6 @@ #include #include -#include "pybind11/buffer_info.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" @@ -541,6 +540,27 @@ class AutotunedKernelCall : public KernelCallBase { absl::Status autotune_status_; }; +absl::StatusOr ZlibUncompress(absl::string_view compressed) { + std::string data; + uLongf dest_len = 5 * compressed.size(); + while (true) { + data.resize(dest_len); + int ret = uncompress(reinterpret_cast(data.data()), &dest_len, + reinterpret_cast(compressed.data()), + compressed.size()); + if (ret == Z_OK) { + // `uncompress` overwrites `dest_len` with the uncompressed size. + data.resize(dest_len); + break; + } else if (ret == Z_BUF_ERROR) { + dest_len *= 2; // The string buffer wasn't large enough. + } else { + return absl::InvalidArgumentError("Failed to uncompress opaque data."); + } + } + return data; +} + absl::StatusOr GetKernelCall(absl::string_view opaque) { static absl::Mutex mutex; static auto& kernel_calls = @@ -552,23 +572,7 @@ absl::StatusOr GetKernelCall(absl::string_view opaque) { if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) return it->second.get(); // The opaque data is a zlib compressed protobuf. - std::string serialized; - uLongf dest_len = 5 * opaque.size(); - while (true) { - serialized.resize(dest_len); - int ret = uncompress(reinterpret_cast(serialized.data()), &dest_len, - reinterpret_cast(opaque.data()), - opaque.size()); - if (ret == Z_OK) { - // `uncompress` overwrites `dest_len` with the uncompressed size. - serialized.resize(dest_len); - break; - } else if (ret == Z_BUF_ERROR) { - dest_len *= 2; // The string buffer wasn't large enough. - } else { - return absl::InvalidArgumentError("Failed to uncompress opaque data."); - } - } + JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); TritonAnyKernelCall proto; if (!proto.ParseFromString(serialized)) { @@ -652,16 +656,12 @@ PYBIND11_MODULE(_triton, m) { py::class_(m, "TritonKernelCall") .def(py::init>()) - .def("to_proto", - [](const KernelCall& kernel_call, - py::bytes metadata) -> absl::StatusOr { - TritonAnyKernelCall proto; - *proto.mutable_kernel_call() = kernel_call.ToProto(); - py::buffer_info metadata_info(py::buffer(metadata).request()); - proto.set_metadata(metadata_info.ptr, metadata_info.size); - std::string serialized = proto.SerializeAsString(); - return py::bytes(serialized.data(), serialized.size()); - }); + .def("to_proto", [](const KernelCall& kernel_call, std::string metadata) { + TritonAnyKernelCall proto; + *proto.mutable_kernel_call() = kernel_call.ToProto(); + proto.set_metadata(std::move(metadata)); + return py::bytes(proto.SerializeAsString()); + }); py::class_(m, "TritonAutotunedKernelCall") .def(py::init<>([](std::string name, @@ -679,14 +679,11 @@ PYBIND11_MODULE(_triton, m) { std::move(input_output_aliases)); })) .def("to_proto", - [](const AutotunedKernelCall& kernel_call, - py::bytes metadata) -> absl::StatusOr { + [](const AutotunedKernelCall& kernel_call, std::string metadata) { TritonAnyKernelCall proto; *proto.mutable_autotuned_kernel_call() = kernel_call.ToProto(); - py::buffer_info metadata_info(py::buffer(metadata).request()); - proto.set_metadata(metadata_info.ptr, metadata_info.size); - std::string serialized = proto.SerializeAsString(); - return py::bytes(serialized.data(), serialized.size()); + proto.set_metadata(std::move(metadata)); + return py::bytes(proto.SerializeAsString()); }); m.def("get_custom_call", @@ -701,6 +698,17 @@ PYBIND11_MODULE(_triton, m) { &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); return major * 10 + minor; }); + + m.def( + "get_serialized_metadata", + [](absl::string_view opaque) -> absl::StatusOr { + JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); + TritonAnyKernelCall proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError("Failed to parse serialized data."); + } + return py::bytes(proto.metadata()); + }); } } // namespace jax_triton diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index 40cc69a05..d813953f6 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -26,5 +26,6 @@ try: create_scalar_parameter = _triton.create_scalar_parameter get_compute_capability = _triton.get_compute_capability get_custom_call = _triton.get_custom_call + get_serialized_metadata = _triton.get_serialized_metadata except ImportError: _triton = None