[jax_triton] Add user-specified name field to serialized format.

PiperOrigin-RevId: 557415723
This commit is contained in:
Chris Jones 2023-08-16 02:53:05 -07:00 committed by jax authors
parent c7e8b81a74
commit 4ac2bdc2b1
9 changed files with 126 additions and 54 deletions

View File

@ -35,7 +35,13 @@ from jax._src import util
from jax._src.lax.control_flow import for_loop
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import hlo_helpers
from jax._src.lib import version
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.state import AbstractRef
from jax._src.state import discharge
from jax._src.state import primitives as sp
@ -47,11 +53,6 @@ from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.lib import xla_client as xc
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas import primitives
from jax._src.pallas import indexing
from jax._src.pallas import utils as pallas_utils
from jax_triton import triton_lib
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
from jax_triton.triton_lib import get_triton_type
@ -1687,12 +1688,15 @@ def pallas_call_lowering(
if triton_params is None:
triton_params = {}
serialized_metadata = triton_params.get("serialized_metadata", b"")
if version >= (0, 4, 15):
kernel_call_proto = kernel_call.to_proto(name, serialized_metadata)
else:
kernel_call_proto = kernel_call.to_proto(serialized_metadata)
return hlo_helpers.custom_call(
call_target_name=name,
out_types=out_types,
operands=in_nodes,
backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
backend_config=zlib.compress(kernel_call_proto),
operand_layouts=triton_lib.avals_to_layouts(ctx.avals_in),
result_layouts=triton_lib.avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),

View File

@ -390,6 +390,7 @@ cc_library(
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":triton_utils",
"//jaxlib/gpu:triton_cc_proto",
"@xla//xla/service:custom_call_status",
"@xla//xla/stream_executor/cuda:cudart_stub",
@ -403,6 +404,20 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
],
)
# Utilities for working with the serialized Triton kernel call payload: builds
# triton_utils.cc / triton_utils.h, which depend on the generated
# TritonAnyKernelCall proto, Abseil status types, and zlib.
cc_library(
name = "triton_utils",
srcs = ["//jaxlib/gpu:triton_utils.cc"],
hdrs = ["//jaxlib/gpu:triton_utils.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib/gpu:triton_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@zlib",
],
)
@ -426,6 +441,7 @@ pybind_extension(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":triton_kernels",
":triton_utils",
"//jaxlib:kernel_pybind11_helpers",
"//jaxlib/gpu:triton_cc_proto",
"@com_google_absl//absl/status:statusor",

View File

@ -50,6 +50,8 @@ exports_files(srcs = [
"triton.cc",
"triton_kernels.cc",
"triton_kernels.h",
"triton_utils.cc",
"triton_utils.h",
"vendor.h",
])

View File

@ -12,6 +12,7 @@
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/triton_kernels.h"
#include "jaxlib/gpu/triton_utils.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "pybind11_abseil/status_casters.h" // IWYU pragma: keep
@ -80,9 +81,11 @@ PYBIND11_MODULE(_triton, m) {
py::class_<KernelCall>(m, "TritonKernelCall")
.def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
std::vector<KernelCall::Parameter>>())
.def("to_proto", [](const KernelCall& kernel_call, std::string metadata) {
.def("to_proto", [](const KernelCall& kernel_call, std::string name,
std::string metadata) {
jax_triton::TritonAnyKernelCall proto;
*proto.mutable_kernel_call() = kernel_call.ToProto();
proto.set_name(std::move(name));
proto.set_metadata(std::move(metadata));
return py::bytes(proto.SerializeAsString());
});
@ -102,10 +105,11 @@ PYBIND11_MODULE(_triton, m) {
std::move(name), std::move(configs),
std::move(input_output_aliases));
}))
.def("to_proto",
[](const AutotunedKernelCall& kernel_call, std::string metadata) {
.def("to_proto", [](const AutotunedKernelCall& kernel_call,
std::string name, std::string metadata) {
jax_triton::TritonAnyKernelCall proto;
*proto.mutable_autotuned_kernel_call() = kernel_call.ToProto();
proto.set_name(std::move(name));
proto.set_metadata(std::move(metadata));
return py::bytes(proto.SerializeAsString());
});
@ -123,15 +127,11 @@ PYBIND11_MODULE(_triton, m) {
return major * 10 + minor;
});
m.def(
"get_serialized_metadata",
m.def("get_serialized_metadata",
[](absl::string_view opaque) -> absl::StatusOr<py::bytes> {
JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque));
jax_triton::TritonAnyKernelCall proto;
if (!proto.ParseFromString(serialized)) {
return absl::InvalidArgumentError("Failed to parse serialized data.");
}
return py::bytes(proto.metadata());
JAX_ASSIGN_OR_RETURN(std::string metadata,
GetTritonKernelCallSerializedMetadata(opaque));
return py::bytes(metadata);
});
}

View File

@ -3,7 +3,7 @@ syntax = "proto3";
package jax_triton;
message TritonKernel {
string kernel_name = 1;
string kernel_name = 1; // Kernel function name within module.
uint32 num_warps = 2;
uint32 shared_mem_bytes = 3;
string ptx = 4;
@ -49,7 +49,7 @@ message TritonAutotunedKernelCall {
uint64 buffer_size_bytes = 3;
}
string name = 1;
string name = 1; // Name used in auto-tuning log messages.
repeated Config configs = 2;
repeated InputOutputAlias input_output_aliases = 3;
}
@ -60,4 +60,5 @@ message TritonAnyKernelCall {
TritonAutotunedKernelCall autotuned_kernel_call = 2;
}
bytes metadata = 3;
string name = 4; // User assigned name.
}

View File

@ -1,7 +1,5 @@
#include "jaxlib/gpu/triton_kernels.h"
#include <zlib.h>
#include <algorithm>
#include <cstdint>
#include <memory>
@ -23,6 +21,7 @@
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/triton_utils.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
#include "xla/stream_executor/gpu/asm_compiler.h"
@ -519,25 +518,4 @@ void TritonKernelCall(CUstream stream, void** buffers, const char* opaque,
}
}
// Inflates the zlib stream held in `compressed` and returns the raw bytes.
//
// The uncompressed size is not known up front, so the output buffer starts at
// five times the input size and is doubled each time zlib reports that it ran
// out of room. Any other zlib failure yields an InvalidArgument error.
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
  const Bytef* const source =
      reinterpret_cast<const Bytef*>(compressed.data());
  std::string output;
  uLongf capacity = 5 * compressed.size();
  for (;;) {
    output.resize(capacity);
    const int status = uncompress(reinterpret_cast<Bytef*>(output.data()),
                                  &capacity, source, compressed.size());
    if (status == Z_BUF_ERROR) {
      capacity *= 2;  // The output buffer was too small; retry with more room.
      continue;
    }
    if (status != Z_OK) {
      return absl::InvalidArgumentError("Failed to uncompress opaque data.");
    }
    // On success `uncompress` stores the number of bytes it produced in
    // `capacity`; trim the unused tail of the buffer.
    output.resize(capacity);
    return output;
  }
}
} // namespace jax::JAX_GPU_NAMESPACE

View File

@ -8,10 +8,8 @@
#include <variant>
#include <vector>
#include "absl/cleanup/cleanup.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/triton.pb.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
@ -101,8 +99,6 @@ class AutotunedKernelCall {
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases_;
};
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed);
} // namespace jax::JAX_GPU_NAMESPACE
#endif // JAXLIB_GPU_TRITON_H_

View File

@ -0,0 +1,55 @@
#include "jaxlib/gpu/triton_utils.h"
#include <zlib.h>
#include <string>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/triton.pb.h"
namespace jax::JAX_GPU_NAMESPACE {
// Inflates the zlib stream held in `compressed` and returns the raw bytes.
//
// The uncompressed size is not known up front, so the output buffer starts at
// five times the input size and is doubled each time zlib reports that it ran
// out of room. Any other zlib failure yields an InvalidArgument error.
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed) {
  const Bytef* const source =
      reinterpret_cast<const Bytef*>(compressed.data());
  std::string output;
  uLongf capacity = 5 * compressed.size();
  for (;;) {
    output.resize(capacity);
    const int status = uncompress(reinterpret_cast<Bytef*>(output.data()),
                                  &capacity, source, compressed.size());
    if (status == Z_BUF_ERROR) {
      capacity *= 2;  // The output buffer was too small; retry with more room.
      continue;
    }
    if (status != Z_OK) {
      return absl::InvalidArgumentError("Failed to uncompress opaque data.");
    }
    // On success `uncompress` stores the number of bytes it produced in
    // `capacity`; trim the unused tail of the buffer.
    output.resize(capacity);
    return output;
  }
}
// Extracts the `name` field from `opaque`, which is expected to hold a
// zlib-compressed, serialized `jax_triton::TritonAnyKernelCall`.
//
// Returns an InvalidArgument error if the uncompressed bytes do not parse as
// that message.
absl::StatusOr<std::string> GetTritonKernelCallName(absl::string_view opaque) {
  JAX_ASSIGN_OR_RETURN(std::string uncompressed, ZlibUncompress(opaque));
  jax_triton::TritonAnyKernelCall kernel_call;
  if (kernel_call.ParseFromString(uncompressed)) {
    return kernel_call.name();
  }
  return absl::InvalidArgumentError("Failed to parse serialized data.");
}
// Extracts the `metadata` bytes from `opaque`, which is expected to hold a
// zlib-compressed, serialized `jax_triton::TritonAnyKernelCall`.
//
// Returns an InvalidArgument error if the uncompressed bytes do not parse as
// that message.
absl::StatusOr<std::string> GetTritonKernelCallSerializedMetadata(
    absl::string_view opaque) {
  JAX_ASSIGN_OR_RETURN(std::string uncompressed, ZlibUncompress(opaque));
  jax_triton::TritonAnyKernelCall kernel_call;
  if (kernel_call.ParseFromString(uncompressed)) {
    return kernel_call.metadata();
  }
  return absl::InvalidArgumentError("Failed to parse serialized data.");
}
} // namespace jax::JAX_GPU_NAMESPACE

20
jaxlib/gpu/triton_utils.h Normal file
View File

@ -0,0 +1,20 @@
#ifndef JAXLIB_GPU_TRITON_UTILS_H_
#define JAXLIB_GPU_TRITON_UTILS_H_
#include <string>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "jaxlib/gpu/vendor.h"
namespace jax::JAX_GPU_NAMESPACE {
// Uncompresses the zlib-compressed bytes in `compressed` and returns the
// result. Returns an InvalidArgument error if zlib rejects the data.
absl::StatusOr<std::string> ZlibUncompress(absl::string_view compressed);
// Uncompresses `opaque`, parses it as a `jax_triton::TritonAnyKernelCall`, and
// returns the message's `name` field. Returns an error status if `opaque`
// cannot be uncompressed or parsed.
absl::StatusOr<std::string> GetTritonKernelCallName(absl::string_view opaque);
// Uncompresses `opaque`, parses it as a `jax_triton::TritonAnyKernelCall`, and
// returns the message's `metadata` field. Returns an error status if `opaque`
// cannot be uncompressed or parsed.
absl::StatusOr<std::string> GetTritonKernelCallSerializedMetadata(
absl::string_view opaque);
} // namespace jax::JAX_GPU_NAMESPACE
#endif // JAXLIB_GPU_TRITON_UTILS_H_