mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add get_serialized_metadata
function to retrieve metadata from op's opaque data.
PiperOrigin-RevId: 544608895
This commit is contained in:
parent
2575307c04
commit
3f9da19c63
@ -12,7 +12,6 @@
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "pybind11/buffer_info.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
@ -541,6 +540,27 @@ class AutotunedKernelCall : public KernelCallBase {
|
||||
absl::Status autotune_status_;
|
||||
};
|
||||
|
||||
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
|
||||
std::string data;
|
||||
uLongf dest_len = 5 * compressed.size();
|
||||
while (true) {
|
||||
data.resize(dest_len);
|
||||
int ret = uncompress(reinterpret_cast<Bytef*>(data.data()), &dest_len,
|
||||
reinterpret_cast<const Bytef*>(compressed.data()),
|
||||
compressed.size());
|
||||
if (ret == Z_OK) {
|
||||
// `uncompress` overwrites `dest_len` with the uncompressed size.
|
||||
data.resize(dest_len);
|
||||
break;
|
||||
} else if (ret == Z_BUF_ERROR) {
|
||||
dest_len *= 2; // The string buffer wasn't large enough.
|
||||
} else {
|
||||
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
absl::StatusOr<KernelCallBase*> GetKernelCall(absl::string_view opaque) {
|
||||
static absl::Mutex mutex;
|
||||
static auto& kernel_calls =
|
||||
@ -552,23 +572,7 @@ absl::StatusOr<KernelCallBase*> GetKernelCall(absl::string_view opaque) {
|
||||
if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) return it->second.get();
|
||||
|
||||
// The opaque data is a zlib compressed protobuf.
|
||||
std::string serialized;
|
||||
uLongf dest_len = 5 * opaque.size();
|
||||
while (true) {
|
||||
serialized.resize(dest_len);
|
||||
int ret = uncompress(reinterpret_cast<Bytef*>(serialized.data()), &dest_len,
|
||||
reinterpret_cast<const Bytef*>(opaque.data()),
|
||||
opaque.size());
|
||||
if (ret == Z_OK) {
|
||||
// `uncompress` overwrites `dest_len` with the uncompressed size.
|
||||
serialized.resize(dest_len);
|
||||
break;
|
||||
} else if (ret == Z_BUF_ERROR) {
|
||||
dest_len *= 2; // The string buffer wasn't large enough.
|
||||
} else {
|
||||
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
|
||||
}
|
||||
}
|
||||
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
|
||||
|
||||
TritonAnyKernelCall proto;
|
||||
if (!proto.ParseFromString(serialized)) {
|
||||
@ -652,15 +656,11 @@ PYBIND11_MODULE(_triton, m) {
|
||||
py::class_<KernelCall>(m, "TritonKernelCall")
|
||||
.def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
|
||||
std::vector<KernelCall::Parameter>>())
|
||||
.def("to_proto",
|
||||
[](const KernelCall& kernel_call,
|
||||
py::bytes metadata) -> absl::StatusOr<py::bytes> {
|
||||
.def("to_proto", [](const KernelCall& kernel_call, std::string metadata) {
|
||||
TritonAnyKernelCall proto;
|
||||
*proto.mutable_kernel_call() = kernel_call.ToProto();
|
||||
py::buffer_info metadata_info(py::buffer(metadata).request());
|
||||
proto.set_metadata(metadata_info.ptr, metadata_info.size);
|
||||
std::string serialized = proto.SerializeAsString();
|
||||
return py::bytes(serialized.data(), serialized.size());
|
||||
proto.set_metadata(std::move(metadata));
|
||||
return py::bytes(proto.SerializeAsString());
|
||||
});
|
||||
|
||||
py::class_<AutotunedKernelCall>(m, "TritonAutotunedKernelCall")
|
||||
@ -679,14 +679,11 @@ PYBIND11_MODULE(_triton, m) {
|
||||
std::move(input_output_aliases));
|
||||
}))
|
||||
.def("to_proto",
|
||||
[](const AutotunedKernelCall& kernel_call,
|
||||
py::bytes metadata) -> absl::StatusOr<py::bytes> {
|
||||
[](const AutotunedKernelCall& kernel_call, std::string metadata) {
|
||||
TritonAnyKernelCall proto;
|
||||
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
|
||||
py::buffer_info metadata_info(py::buffer(metadata).request());
|
||||
proto.set_metadata(metadata_info.ptr, metadata_info.size);
|
||||
std::string serialized = proto.SerializeAsString();
|
||||
return py::bytes(serialized.data(), serialized.size());
|
||||
proto.set_metadata(std::move(metadata));
|
||||
return py::bytes(proto.SerializeAsString());
|
||||
});
|
||||
|
||||
m.def("get_custom_call",
|
||||
@ -701,6 +698,17 @@ PYBIND11_MODULE(_triton, m) {
|
||||
&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
|
||||
return major * 10 + minor;
|
||||
});
|
||||
|
||||
m.def(
|
||||
"get_serialized_metadata",
|
||||
[](absl::string_view opaque) -> absl::StatusOr<py::bytes> {
|
||||
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
|
||||
TritonAnyKernelCall proto;
|
||||
if (!proto.ParseFromString(serialized)) {
|
||||
return absl::InvalidArgumentError("Failed to parse serialized data.");
|
||||
}
|
||||
return py::bytes(proto.metadata());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace jax_triton
|
||||
|
@ -26,5 +26,6 @@ try:
|
||||
create_scalar_parameter = _triton.create_scalar_parameter
|
||||
get_compute_capability = _triton.get_compute_capability
|
||||
get_custom_call = _triton.get_custom_call
|
||||
get_serialized_metadata = _triton.get_serialized_metadata
|
||||
except ImportError:
|
||||
_triton = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user