From 4ac2bdc2b1d71ec0010412a3248ad567f145f773 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 16 Aug 2023 02:53:05 -0700 Subject: [PATCH] [jax_triton] Add user-specified `name` field to serialized format. PiperOrigin-RevId: 557415723 --- jax/_src/pallas/triton/lowering.py | 18 ++++++---- jaxlib/cuda/BUILD | 16 +++++++++ jaxlib/gpu/BUILD | 2 ++ jaxlib/gpu/triton.cc | 36 +++++++++---------- jaxlib/gpu/triton.proto | 5 +-- jaxlib/gpu/triton_kernels.cc | 24 +------------ jaxlib/gpu/triton_kernels.h | 4 --- jaxlib/gpu/triton_utils.cc | 55 ++++++++++++++++++++++++++++++ jaxlib/gpu/triton_utils.h | 20 +++++++++++ 9 files changed, 126 insertions(+), 54 deletions(-) create mode 100644 jaxlib/gpu/triton_utils.cc create mode 100644 jaxlib/gpu/triton_utils.h diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 93ed744cf..fdb615109 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -35,7 +35,13 @@ from jax._src import util from jax._src.lax.control_flow import for_loop from jax._src.lib import gpu_triton as triton_kernel_call_lib from jax._src.lib import hlo_helpers +from jax._src.lib import version from jax._src.lib.mlir import ir +from jax._src.pallas import core as pallas_core +from jax._src.pallas import indexing +from jax._src.pallas import primitives +from jax._src.pallas import utils as pallas_utils +from jax._src.pallas.pallas_call import pallas_call_p from jax._src.state import AbstractRef from jax._src.state import discharge from jax._src.state import primitives as sp @@ -47,11 +53,6 @@ from jax.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.lib import xla_client as xc import jax.numpy as jnp -from jax._src.pallas import core as pallas_core -from jax._src.pallas.pallas_call import pallas_call_p -from jax._src.pallas import primitives -from jax._src.pallas import indexing -from jax._src.pallas import utils as pallas_utils from jax_triton import triton_lib from jax_triton.triton_lib import compile_ttir_to_ptx_inplace from jax_triton.triton_lib import get_triton_type @@ -1687,12 +1688,15 @@ def pallas_call_lowering( if triton_params is None: triton_params = {} serialized_metadata = triton_params.get("serialized_metadata", b"") - + if version >= (0, 4, 15): + kernel_call_proto = kernel_call.to_proto(name, serialized_metadata) + else: + kernel_call_proto = kernel_call.to_proto(serialized_metadata) return hlo_helpers.custom_call( call_target_name=name, out_types=out_types, operands=in_nodes, - backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)), + backend_config=zlib.compress(kernel_call_proto), operand_layouts=triton_lib.avals_to_layouts(ctx.avals_in), result_layouts=triton_lib.avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 21ab23b05..eb5daa5be 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -390,6 +390,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + ":triton_utils", "//jaxlib/gpu:triton_cc_proto", "@xla//xla/service:custom_call_status", "@xla//xla/stream_executor/cuda:cudart_stub", @@ -403,6 +404,20 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "triton_utils", + srcs = ["//jaxlib/gpu:triton_utils.cc"], + hdrs = ["//jaxlib/gpu:triton_utils.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@zlib", ], ) @@ -426,6 +441,7 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":triton_kernels", + ":triton_utils", "//jaxlib:kernel_pybind11_helpers", "//jaxlib/gpu:triton_cc_proto", "@com_google_absl//absl/status:statusor", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 88a45b46b..2d8c11757 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -50,6 +50,8 @@ exports_files(srcs = [ "triton.cc", "triton_kernels.cc", "triton_kernels.h", + "triton_utils.cc", + "triton_utils.h", "vendor.h", ]) diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 57ebeaf45..253a4dfa9 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -12,6 +12,7 @@ #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/triton_kernels.h" +#include "jaxlib/gpu/triton_utils.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_pybind11_helpers.h" #include "pybind11_abseil/status_casters.h" // IWYU pragma: keep @@ -80,9 +81,11 @@ PYBIND11_MODULE(_triton, m) { py::class_(m, "TritonKernelCall") .def(py::init>()) - .def("to_proto", [](const KernelCall& kernel_call, std::string metadata) { + .def("to_proto", [](const KernelCall& kernel_call, std::string name, + std::string metadata) { jax_triton::TritonAnyKernelCall proto; *proto.mutable_kernel_call() = kernel_call.ToProto(); + proto.set_name(std::move(name)); proto.set_metadata(std::move(metadata)); return py::bytes(proto.SerializeAsString()); }); @@ -102,13 +105,14 @@ PYBIND11_MODULE(_triton, m) { std::move(name), std::move(configs), std::move(input_output_aliases)); })) - .def("to_proto", - [](const AutotunedKernelCall& kernel_call, std::string metadata) { - jax_triton::TritonAnyKernelCall proto; - *proto.mutable_autotuned_kernel_call() = kernel_call.ToProto(); - proto.set_metadata(std::move(metadata)); - return py::bytes(proto.SerializeAsString()); - }); + .def("to_proto", [](const AutotunedKernelCall& kernel_call, + std::string name, std::string metadata) { + jax_triton::TritonAnyKernelCall proto; + *proto.mutable_autotuned_kernel_call() = kernel_call.ToProto(); + proto.set_name(std::move(name)); + proto.set_metadata(std::move(metadata)); + return py::bytes(proto.SerializeAsString()); + }); m.def("get_custom_call", [] { return EncapsulateFunction(&TritonKernelCall); }); @@ -123,16 +127,12 @@ PYBIND11_MODULE(_triton, m) { return major * 10 + minor; }); - m.def( - "get_serialized_metadata", - [](absl::string_view opaque) -> absl::StatusOr { - JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); - jax_triton::TritonAnyKernelCall proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError("Failed to parse serialized data."); - } - return py::bytes(proto.metadata()); - }); + m.def("get_serialized_metadata", + [](absl::string_view opaque) -> absl::StatusOr { + JAX_ASSIGN_OR_RETURN(std::string metadata, + GetTritonKernelCallSerializedMetadata(opaque)); + return py::bytes(metadata); + }); } } // namespace jax::JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/triton.proto b/jaxlib/gpu/triton.proto index bbb2ff218..7d3037a28 100644 --- a/jaxlib/gpu/triton.proto +++ b/jaxlib/gpu/triton.proto @@ -3,7 +3,7 @@ syntax = "proto3"; package jax_triton; message TritonKernel { - string kernel_name = 1; + string kernel_name = 1; // Kernel function name within module. uint32 num_warps = 2; uint32 shared_mem_bytes = 3; string ptx = 4; @@ -49,7 +49,7 @@ message TritonAutotunedKernelCall { uint64 buffer_size_bytes = 3; } - string name = 1; + string name = 1; // Name used in auto-tuning log messages. repeated Config configs = 2; repeated InputOutputAlias input_output_aliases = 3; } @@ -60,4 +60,5 @@ message TritonAnyKernelCall { TritonAutotunedKernelCall autotuned_kernel_call = 2; } bytes metadata = 3; + string name = 4; // User assigned name. } diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 311e0e9c3..41ab5cc64 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -1,7 +1,5 @@ #include "jaxlib/gpu/triton_kernels.h" -#include - #include #include #include @@ -23,6 +21,7 @@ #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" +#include "jaxlib/gpu/triton_utils.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" #include "xla/stream_executor/gpu/asm_compiler.h" @@ -519,25 +518,4 @@ void TritonKernelCall(CUstream stream, void** buffers, const char* opaque, } } -absl::StatusOr ZlibUncompress(absl::string_view compressed) { - std::string data; - uLongf dest_len = 5 * compressed.size(); - while (true) { - data.resize(dest_len); - int ret = uncompress(reinterpret_cast(data.data()), &dest_len, - reinterpret_cast(compressed.data()), - compressed.size()); - if (ret == Z_OK) { - // `uncompress` overwrites `dest_len` with the uncompressed size. - data.resize(dest_len); - break; - } else if (ret == Z_BUF_ERROR) { - dest_len *= 2; // The string buffer wasn't large enough. - } else { - return absl::InvalidArgumentError("Failed to uncompress opaque data."); - } - } - return data; -} - } // namespace jax::JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index 20a822269..c625a168c 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -8,10 +8,8 @@ #include #include -#include "absl/cleanup/cleanup.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" @@ -101,8 +99,6 @@ class AutotunedKernelCall { std::vector> input_output_aliases_; }; -absl::StatusOr ZlibUncompress(absl::string_view compressed); - } // namespace jax::JAX_GPU_NAMESPACE #endif // JAXLIB_GPU_TRITON_H_ diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc new file mode 100644 index 000000000..b3a077911 --- /dev/null +++ b/jaxlib/gpu/triton_utils.cc @@ -0,0 +1,55 @@ +#include "jaxlib/gpu/triton_utils.h" + +#include + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/triton.pb.h" + +namespace jax::JAX_GPU_NAMESPACE { + +absl::StatusOr ZlibUncompress(absl::string_view compressed) { + std::string data; + uLongf dest_len = 5 * compressed.size(); + while (true) { + data.resize(dest_len); + int ret = uncompress(reinterpret_cast(data.data()), &dest_len, + reinterpret_cast(compressed.data()), + compressed.size()); + if (ret == Z_OK) { + // `uncompress` overwrites `dest_len` with the uncompressed size. + data.resize(dest_len); + break; + } else if (ret == Z_BUF_ERROR) { + dest_len *= 2; // The string buffer wasn't large enough. + } else { + return absl::InvalidArgumentError("Failed to uncompress opaque data."); + } + } + return data; +} + +absl::StatusOr GetTritonKernelCallName(absl::string_view opaque) { + JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); + jax_triton::TritonAnyKernelCall proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError("Failed to parse serialized data."); + } + return proto.name(); +} + +absl::StatusOr GetTritonKernelCallSerializedMetadata( + absl::string_view opaque) { + JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); + jax_triton::TritonAnyKernelCall proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError("Failed to parse serialized data."); + } + return proto.metadata(); +} + +} // namespace jax::JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h new file mode 100644 index 000000000..0c286391e --- /dev/null +++ b/jaxlib/gpu/triton_utils.h @@ -0,0 +1,20 @@ +#ifndef JAXLIB_GPU_TRITON_UTILS_H_ +#define JAXLIB_GPU_TRITON_UTILS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax::JAX_GPU_NAMESPACE { + +absl::StatusOr ZlibUncompress(absl::string_view compressed); +absl::StatusOr GetTritonKernelCallName(absl::string_view opaque); +absl::StatusOr GetTritonKernelCallSerializedMetadata( + absl::string_view opaque); + +} // namespace jax::JAX_GPU_NAMESPACE + +#endif // JAXLIB_GPU_TRITON_UTILS_H_