Add get_serialized_metadata function to retrieve metadata from op's opaque data.

PiperOrigin-RevId: 544608895
This commit is contained in:
Chris Jones 2023-06-30 03:22:48 -07:00 committed by jax authors
parent 2575307c04
commit 3f9da19c63
2 changed files with 43 additions and 34 deletions

View File

@ -12,7 +12,6 @@
#include <variant>
#include <vector>
#include "pybind11/buffer_info.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
@ -541,6 +540,27 @@ class AutotunedKernelCall : public KernelCallBase {
absl::Status autotune_status_;
};
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
std::string data;
uLongf dest_len = 5 * compressed.size();
while (true) {
data.resize(dest_len);
int ret = uncompress(reinterpret_cast<Bytef*>(data.data()), &dest_len,
reinterpret_cast<const Bytef*>(compressed.data()),
compressed.size());
if (ret == Z_OK) {
// `uncompress` overwrites `dest_len` with the uncompressed size.
data.resize(dest_len);
break;
} else if (ret == Z_BUF_ERROR) {
dest_len *= 2; // The string buffer wasn't large enough.
} else {
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
}
}
return data;
}
absl::StatusOr<KernelCallBase*> GetKernelCall(absl::string_view opaque) {
static absl::Mutex mutex;
static auto& kernel_calls =
@ -552,23 +572,7 @@ absl::StatusOr<KernelCallBase*> GetKernelCall(absl::string_view opaque) {
if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) return it->second.get();
// The opaque data is a zlib compressed protobuf.
std::string serialized;
uLongf dest_len = 5 * opaque.size();
while (true) {
serialized.resize(dest_len);
int ret = uncompress(reinterpret_cast<Bytef*>(serialized.data()), &dest_len,
reinterpret_cast<const Bytef*>(opaque.data()),
opaque.size());
if (ret == Z_OK) {
// `uncompress` overwrites `dest_len` with the uncompressed size.
serialized.resize(dest_len);
break;
} else if (ret == Z_BUF_ERROR) {
dest_len *= 2; // The string buffer wasn't large enough.
} else {
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
}
}
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
TritonAnyKernelCall proto;
if (!proto.ParseFromString(serialized)) {
@ -652,15 +656,11 @@ PYBIND11_MODULE(_triton, m) {
py::class_<KernelCall>(m, "TritonKernelCall")
.def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
std::vector<KernelCall::Parameter>>())
.def("to_proto",
[](const KernelCall& kernel_call,
py::bytes metadata) -> absl::StatusOr<py::bytes> {
.def("to_proto", [](const KernelCall& kernel_call, std::string metadata) {
TritonAnyKernelCall proto;
*proto.mutable_kernel_call() = kernel_call.ToProto();
py::buffer_info metadata_info(py::buffer(metadata).request());
proto.set_metadata(metadata_info.ptr, metadata_info.size);
std::string serialized = proto.SerializeAsString();
return py::bytes(serialized.data(), serialized.size());
proto.set_metadata(std::move(metadata));
return py::bytes(proto.SerializeAsString());
});
py::class_<AutotunedKernelCall>(m, "TritonAutotunedKernelCall")
@ -679,14 +679,11 @@ PYBIND11_MODULE(_triton, m) {
std::move(input_output_aliases));
}))
.def("to_proto",
[](const AutotunedKernelCall& kernel_call,
py::bytes metadata) -> absl::StatusOr<py::bytes> {
[](const AutotunedKernelCall& kernel_call, std::string metadata) {
TritonAnyKernelCall proto;
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
py::buffer_info metadata_info(py::buffer(metadata).request());
proto.set_metadata(metadata_info.ptr, metadata_info.size);
std::string serialized = proto.SerializeAsString();
return py::bytes(serialized.data(), serialized.size());
proto.set_metadata(std::move(metadata));
return py::bytes(proto.SerializeAsString());
});
m.def("get_custom_call",
@ -701,6 +698,17 @@ PYBIND11_MODULE(_triton, m) {
&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
return major * 10 + minor;
});
m.def(
"get_serialized_metadata",
[](absl::string_view opaque) -> absl::StatusOr<py::bytes> {
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
TritonAnyKernelCall proto;
if (!proto.ParseFromString(serialized)) {
return absl::InvalidArgumentError("Failed to parse serialized data.");
}
return py::bytes(proto.metadata());
});
}
} // namespace jax_triton

View File

@ -26,5 +26,6 @@ try:
create_scalar_parameter = _triton.create_scalar_parameter
get_compute_capability = _triton.get_compute_capability
get_custom_call = _triton.get_custom_call
get_serialized_metadata = _triton.get_serialized_metadata
except ImportError:
_triton = None