mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00

This allows users to register the callback from C++ when not using the default call target name. PiperOrigin-RevId: 544029098
11 lines
318 B
C
11 lines
318 B
C
#ifndef JAXLIB_GPU_TRITON_H_
|
|
#define JAXLIB_GPU_TRITON_H_
|
|
|
|
#include <stddef.h>

#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
|
|
|
|
/**
 * Launches a Triton kernel as an XLA GPU custom call.
 *
 * The signature matches XLA's status-returning GPU custom-call convention
 * (stream, buffers, opaque, opaque_len, status). It is declared in this header
 * so that C++ users can register the callback themselves when they are not
 * using the default call target name.
 *
 * @param stream      Vendor stream type (from jaxlib/gpu/vendor.h). Per the
 *                    XLA custom-call convention, work is enqueued on this
 *                    stream.
 * @param buffers     Array of device buffer pointers supplied by XLA. Per the
 *                    XLA custom-call convention, operands come first, then
 *                    results.
 * @param opaque      Opaque bytes attached to the custom call. Their encoding
 *                    is defined by the implementation, which is not visible
 *                    in this header. NOTE(review): document the format here.
 * @param opaque_len  Length of @p opaque in bytes. @p opaque is not assumed to
 *                    be NUL-terminated.
 * @param status      XLA status handle used to report failure
 *                    (see xla/service/custom_call_status.h).
 */
void LaunchTritonKernel(CUstream stream, void** buffers, const char* opaque,
                        size_t opaque_len, XlaCustomCallStatus* status);
|
|
|
|
#endif // JAXLIB_GPU_TRITON_H_
|