diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index be81d122c..21d1cb319 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -384,6 +384,7 @@ cc_library( pybind_extension( name = "_triton", srcs = ["//jaxlib/gpu:triton.cc"], + hdrs = ["//jaxlib/gpu:triton.h"], copts = [ "-fexceptions", "-fno-strict-aliasing", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index da6afbbff..3b011ee2b 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -48,6 +48,7 @@ exports_files(srcs = [ "sparse_kernels.cc", "sparse_kernels.h", "triton.cc", + "triton.h", "vendor.h", ]) diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index fcbd64038..4ad764472 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -1,3 +1,5 @@ +#include "jaxlib/gpu/triton.h" + #include #include diff --git a/jaxlib/gpu/triton.h b/jaxlib/gpu/triton.h new file mode 100644 index 000000000..307a0fbdd --- /dev/null +++ b/jaxlib/gpu/triton.h @@ -0,0 +1,10 @@ +#ifndef JAXLIB_GPU_TRITON_H_ +#define JAXLIB_GPU_TRITON_H_ + +#include "jaxlib/gpu/vendor.h" +#include "xla/service/custom_call_status.h" + +void LaunchTritonKernel(CUstream stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status); + +#endif // JAXLIB_GPU_TRITON_H_