rocm_jax/jaxlib/gpu/triton.h
Chris Jones d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00

11 lines
318 B
C

#ifndef JAXLIB_GPU_TRITON_H_
#define JAXLIB_GPU_TRITON_H_

// size_t is used in the declaration below; include it directly rather than
// relying on it arriving transitively through vendor.h.
#include <cstddef>

#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"

// Custom-call callback that launches a Triton kernel.
//
// Exposed in this header so that C++ code can register the callback with XLA
// itself, e.g. when it is not using the default custom-call target name.
//
// The parameter list follows XLA's status-returning GPU custom-call signature:
//   stream     - GPU stream handle passed in by XLA for this call.
//   buffers    - array of device buffer pointers supplied by XLA for the
//                call's operands and results.
//   opaque     - the custom call's opaque payload. NOTE(review): presumably a
//                serialized description of the kernel launch; confirm against
//                the definition.
//   opaque_len - number of bytes in `opaque`.
//   status     - XLA status object through which the callback can report
//                failure back to XLA.
void LaunchTritonKernel(CUstream stream, void** buffers, const char* opaque,
                        size_t opaque_len, XlaCustomCallStatus* status);

#endif  // JAXLIB_GPU_TRITON_H_