[jax_triton] Expose Triton custom call callback in header file.

This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
This commit is contained in:
Chris Jones 2023-06-28 05:31:20 -07:00 committed by jax authors
parent 5b698c899e
commit d4e2464340
4 changed files with 14 additions and 0 deletions

View File

@ -384,6 +384,7 @@ cc_library(
pybind_extension(
name = "_triton",
srcs = ["//jaxlib/gpu:triton.cc"],
hdrs = ["//jaxlib/gpu:triton.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",

View File

@ -48,6 +48,7 @@ exports_files(srcs = [
"sparse_kernels.cc",
"sparse_kernels.h",
"triton.cc",
"triton.h",
"vendor.h",
])

View File

@ -1,3 +1,5 @@
#include "jaxlib/gpu/triton.h"
#include <zlib.h>
#include <algorithm>

10
jaxlib/gpu/triton.h Normal file
View File

@ -0,0 +1,10 @@
#ifndef JAXLIB_GPU_TRITON_H_
#define JAXLIB_GPU_TRITON_H_
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
void LaunchTritonKernel(CUstream stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#endif // JAXLIB_GPU_TRITON_H_