mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name. PiperOrigin-RevId: 544029098
This commit is contained in:
parent
5b698c899e
commit
d4e2464340
@ -384,6 +384,7 @@ cc_library(
|
||||
pybind_extension(
|
||||
name = "_triton",
|
||||
srcs = ["//jaxlib/gpu:triton.cc"],
|
||||
hdrs = ["//jaxlib/gpu:triton.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
|
@ -48,6 +48,7 @@ exports_files(srcs = [
|
||||
"sparse_kernels.cc",
|
||||
"sparse_kernels.h",
|
||||
"triton.cc",
|
||||
"triton.h",
|
||||
"vendor.h",
|
||||
])
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
#include "jaxlib/gpu/triton.h"
|
||||
|
||||
#include <zlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
10
jaxlib/gpu/triton.h
Normal file
10
jaxlib/gpu/triton.h
Normal file
@ -0,0 +1,10 @@
|
||||
#ifndef JAXLIB_GPU_TRITON_H_
|
||||
#define JAXLIB_GPU_TRITON_H_
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
void LaunchTritonKernel(CUstream stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
#endif // JAXLIB_GPU_TRITON_H_
|
Loading…
x
Reference in New Issue
Block a user