mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[jax_triton] Add user-specified name
field to serialized format.
PiperOrigin-RevId: 557415723
This commit is contained in:
parent
c7e8b81a74
commit
4ac2bdc2b1
@ -35,7 +35,13 @@ from jax._src import util
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
from jax._src.lib import gpu_triton as triton_kernel_call_lib
|
||||
from jax._src.lib import hlo_helpers
|
||||
from jax._src.lib import version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import indexing
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.pallas import utils as pallas_utils
|
||||
from jax._src.pallas.pallas_call import pallas_call_p
|
||||
from jax._src.state import AbstractRef
|
||||
from jax._src.state import discharge
|
||||
from jax._src.state import primitives as sp
|
||||
@ -47,11 +53,6 @@ from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.lib import xla_client as xc
|
||||
import jax.numpy as jnp
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.pallas_call import pallas_call_p
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.pallas import indexing
|
||||
from jax._src.pallas import utils as pallas_utils
|
||||
from jax_triton import triton_lib
|
||||
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
|
||||
from jax_triton.triton_lib import get_triton_type
|
||||
@ -1687,12 +1688,15 @@ def pallas_call_lowering(
|
||||
if triton_params is None:
|
||||
triton_params = {}
|
||||
serialized_metadata = triton_params.get("serialized_metadata", b"")
|
||||
|
||||
if version >= (0, 4, 15):
|
||||
kernel_call_proto = kernel_call.to_proto(name, serialized_metadata)
|
||||
else:
|
||||
kernel_call_proto = kernel_call.to_proto(serialized_metadata)
|
||||
return hlo_helpers.custom_call(
|
||||
call_target_name=name,
|
||||
out_types=out_types,
|
||||
operands=in_nodes,
|
||||
backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
|
||||
backend_config=zlib.compress(kernel_call_proto),
|
||||
operand_layouts=triton_lib.avals_to_layouts(ctx.avals_in),
|
||||
result_layouts=triton_lib.avals_to_layouts(ctx.avals_out),
|
||||
operand_output_aliases=dict(input_output_aliases),
|
||||
|
@ -390,6 +390,7 @@ cc_library(
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
":triton_utils",
|
||||
"//jaxlib/gpu:triton_cc_proto",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/stream_executor/cuda:cudart_stub",
|
||||
@ -403,6 +404,20 @@ cc_library(
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "triton_utils",
|
||||
srcs = ["//jaxlib/gpu:triton_utils.cc"],
|
||||
hdrs = ["//jaxlib/gpu:triton_utils.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib/gpu:triton_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@zlib",
|
||||
],
|
||||
)
|
||||
@ -426,6 +441,7 @@ pybind_extension(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
":triton_kernels",
|
||||
":triton_utils",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"//jaxlib/gpu:triton_cc_proto",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
@ -50,6 +50,8 @@ exports_files(srcs = [
|
||||
"triton.cc",
|
||||
"triton_kernels.cc",
|
||||
"triton_kernels.h",
|
||||
"triton_utils.cc",
|
||||
"triton_utils.h",
|
||||
"vendor.h",
|
||||
])
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/triton.pb.h"
|
||||
#include "jaxlib/gpu/triton_kernels.h"
|
||||
#include "jaxlib/gpu/triton_utils.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "pybind11_abseil/status_casters.h" // IWYU pragma: keep
|
||||
@ -80,9 +81,11 @@ PYBIND11_MODULE(_triton, m) {
|
||||
py::class_<KernelCall>(m, "TritonKernelCall")
|
||||
.def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
|
||||
std::vector<KernelCall::Parameter>>())
|
||||
.def("to_proto", [](const KernelCall& kernel_call, std::string metadata) {
|
||||
.def("to_proto", [](const KernelCall& kernel_call, std::string name,
|
||||
std::string metadata) {
|
||||
jax_triton::TritonAnyKernelCall proto;
|
||||
*proto.mutable_kernel_call() = kernel_call.ToProto();
|
||||
proto.set_name(std::move(name));
|
||||
proto.set_metadata(std::move(metadata));
|
||||
return py::bytes(proto.SerializeAsString());
|
||||
});
|
||||
@ -102,13 +105,14 @@ PYBIND11_MODULE(_triton, m) {
|
||||
std::move(name), std::move(configs),
|
||||
std::move(input_output_aliases));
|
||||
}))
|
||||
.def("to_proto",
|
||||
[](const AutotunedKernelCall& kernel_call, std::string metadata) {
|
||||
jax_triton::TritonAnyKernelCall proto;
|
||||
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
|
||||
proto.set_metadata(std::move(metadata));
|
||||
return py::bytes(proto.SerializeAsString());
|
||||
});
|
||||
.def("to_proto", [](const AutotunedKernelCall& kernel_call,
|
||||
std::string name, std::string metadata) {
|
||||
jax_triton::TritonAnyKernelCall proto;
|
||||
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
|
||||
proto.set_name(std::move(name));
|
||||
proto.set_metadata(std::move(metadata));
|
||||
return py::bytes(proto.SerializeAsString());
|
||||
});
|
||||
|
||||
m.def("get_custom_call",
|
||||
[] { return EncapsulateFunction(&TritonKernelCall); });
|
||||
@ -123,16 +127,12 @@ PYBIND11_MODULE(_triton, m) {
|
||||
return major * 10 + minor;
|
||||
});
|
||||
|
||||
m.def(
|
||||
"get_serialized_metadata",
|
||||
[](absl::string_view opaque) -> absl::StatusOr<py::bytes> {
|
||||
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
|
||||
jax_triton::TritonAnyKernelCall proto;
|
||||
if (!proto.ParseFromString(serialized)) {
|
||||
return absl::InvalidArgumentError("Failed to parse serialized data.");
|
||||
}
|
||||
return py::bytes(proto.metadata());
|
||||
});
|
||||
m.def("get_serialized_metadata",
|
||||
[](absl::string_view opaque) -> absl::StatusOr<py::bytes> {
|
||||
JAX_ASSIGN_OR_RETURN(std::string metadata,
|
||||
GetTritonKernelCallSerializedMetadata(opaque));
|
||||
return py::bytes(metadata);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace jax::JAX_GPU_NAMESPACE
|
||||
|
@ -3,7 +3,7 @@ syntax = "proto3";
|
||||
package jax_triton;
|
||||
|
||||
message TritonKernel {
|
||||
string kernel_name = 1;
|
||||
string kernel_name = 1; // Kernel function name within module.
|
||||
uint32 num_warps = 2;
|
||||
uint32 shared_mem_bytes = 3;
|
||||
string ptx = 4;
|
||||
@ -49,7 +49,7 @@ message TritonAutotunedKernelCall {
|
||||
uint64 buffer_size_bytes = 3;
|
||||
}
|
||||
|
||||
string name = 1;
|
||||
string name = 1; // Name used in auto-tuning log messages.
|
||||
repeated Config configs = 2;
|
||||
repeated InputOutputAlias input_output_aliases = 3;
|
||||
}
|
||||
@ -60,4 +60,5 @@ message TritonAnyKernelCall {
|
||||
TritonAutotunedKernelCall autotuned_kernel_call = 2;
|
||||
}
|
||||
bytes metadata = 3;
|
||||
string name = 4; // User assigned name.
|
||||
}
|
||||
|
@ -1,7 +1,5 @@
|
||||
#include "jaxlib/gpu/triton_kernels.h"
|
||||
|
||||
#include <zlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
@ -23,6 +21,7 @@
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/triton.pb.h"
|
||||
#include "jaxlib/gpu/triton_utils.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
#include "xla/stream_executor/gpu/asm_compiler.h"
|
||||
@ -519,25 +518,4 @@ void TritonKernelCall(CUstream stream, void** buffers, const char* opaque,
|
||||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
|
||||
std::string data;
|
||||
uLongf dest_len = 5 * compressed.size();
|
||||
while (true) {
|
||||
data.resize(dest_len);
|
||||
int ret = uncompress(reinterpret_cast<Bytef*>(data.data()), &dest_len,
|
||||
reinterpret_cast<const Bytef*>(compressed.data()),
|
||||
compressed.size());
|
||||
if (ret == Z_OK) {
|
||||
// `uncompress` overwrites `dest_len` with the uncompressed size.
|
||||
data.resize(dest_len);
|
||||
break;
|
||||
} else if (ret == Z_BUF_ERROR) {
|
||||
dest_len *= 2; // The string buffer wasn't large enough.
|
||||
} else {
|
||||
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
} // namespace jax::JAX_GPU_NAMESPACE
|
||||
|
@ -8,10 +8,8 @@
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/cleanup/cleanup.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/triton.pb.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
@ -101,8 +99,6 @@ class AutotunedKernelCall {
|
||||
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases_;
|
||||
};
|
||||
|
||||
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed);
|
||||
|
||||
} // namespace jax::JAX_GPU_NAMESPACE
|
||||
|
||||
#endif // JAXLIB_GPU_TRITON_H_
|
||||
|
55
jaxlib/gpu/triton_utils.cc
Normal file
55
jaxlib/gpu/triton_utils.cc
Normal file
@ -0,0 +1,55 @@
|
||||
#include "jaxlib/gpu/triton_utils.h"
|
||||
|
||||
#include <zlib.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/triton.pb.h"
|
||||
|
||||
namespace jax::JAX_GPU_NAMESPACE {
|
||||
|
||||
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
|
||||
std::string data;
|
||||
uLongf dest_len = 5 * compressed.size();
|
||||
while (true) {
|
||||
data.resize(dest_len);
|
||||
int ret = uncompress(reinterpret_cast<Bytef*>(data.data()), &dest_len,
|
||||
reinterpret_cast<const Bytef*>(compressed.data()),
|
||||
compressed.size());
|
||||
if (ret == Z_OK) {
|
||||
// `uncompress` overwrites `dest_len` with the uncompressed size.
|
||||
data.resize(dest_len);
|
||||
break;
|
||||
} else if (ret == Z_BUF_ERROR) {
|
||||
dest_len *= 2; // The string buffer wasn't large enough.
|
||||
} else {
|
||||
return absl::InvalidArgumentError("Failed to uncompress opaque data.");
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::string> GetTritonKernelCallName(absl::string_view opaque) {
|
||||
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
|
||||
jax_triton::TritonAnyKernelCall proto;
|
||||
if (!proto.ParseFromString(serialized)) {
|
||||
return absl::InvalidArgumentError("Failed to parse serialized data.");
|
||||
}
|
||||
return proto.name();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::string> GetTritonKernelCallSerializedMetadata(
|
||||
absl::string_view opaque) {
|
||||
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
|
||||
jax_triton::TritonAnyKernelCall proto;
|
||||
if (!proto.ParseFromString(serialized)) {
|
||||
return absl::InvalidArgumentError("Failed to parse serialized data.");
|
||||
}
|
||||
return proto.metadata();
|
||||
}
|
||||
|
||||
} // namespace jax::JAX_GPU_NAMESPACE
|
20
jaxlib/gpu/triton_utils.h
Normal file
20
jaxlib/gpu/triton_utils.h
Normal file
@ -0,0 +1,20 @@
|
||||
#ifndef JAXLIB_GPU_TRITON_UTILS_H_
|
||||
#define JAXLIB_GPU_TRITON_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
namespace jax::JAX_GPU_NAMESPACE {
|
||||
|
||||
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed);
|
||||
absl::StatusOr<std::string> GetTritonKernelCallName(absl::string_view opaque);
|
||||
absl::StatusOr<std::string> GetTritonKernelCallSerializedMetadata(
|
||||
absl::string_view opaque);
|
||||
|
||||
} // namespace jax::JAX_GPU_NAMESPACE
|
||||
|
||||
#endif // JAXLIB_GPU_TRITON_UTILS_H_
|
Loading…
x
Reference in New Issue
Block a user