18 Commits

Author SHA1 Message Date
Rahul Batra
4b7c198a1c [ROCm]: Add get_arch_details for triton kernel call 2024-08-12 06:16:27 +00:00
jax authors
ab3c1b5146 [triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
Rahul Batra
b4b97cd8e8 [ROCm]: Add jax-triton support for ROCm 2023-10-18 07:09:20 +00:00
Peter Hawkins
ac8ea86103 Fix accidental signature change to get_serialized_metadata() from nanobind PR.
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.

PiperOrigin-RevId: 560086072
2023-08-25 07:31:31 -07:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Chris Jones
4ac2bdc2b1 [jax_triton] Add user-specified name field to serialized format.
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Chris Jones
714156df63 [jax_triton] Add support for float scalar inputs.
Python `float`s are inferred as "f64". Values can be passed as "f32" using `np.float32(value)`.

PiperOrigin-RevId: 552036612
2023-07-28 22:57:08 -07:00
Chris Jones
31b862dd56 [jax_triton] Split C++ only parts of Triton custom callback from Python parts.
Register callback with default call target name from C++, enabling Triton calls with the default name to work in C++ only contexts (e.g. serving).

PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
3f9da19c63 Add get_serialized_metadata function to retrieve metadata from op's opaque data.
PiperOrigin-RevId: 544608895
2023-06-30 03:23:28 -07:00
Chris Jones
d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
Chris Jones
b3527f3975 Zlib compress kernel proto.
PiperOrigin-RevId: 542529065
2023-06-22 05:22:53 -07:00
Chris Jones
f238667492 Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Chris Jones
64e73270ff Use EncapsulateFunction utility.
PiperOrigin-RevId: 542299099
2023-06-21 10:37:52 -07:00
Peter Zhizhin
01ed663163 Add a comment at the end of autotuning a Triton function
I'd like to see what config was chosen at the end of autotuning. Tracking `is the new best config` is a bit hard.

PiperOrigin-RevId: 538465212
2023-06-07 06:10:28 -07:00
jax authors
3ba308d4fc [triton] Raise exception on too much shared memory requested
PiperOrigin-RevId: 538135346
2023-06-06 03:55:53 -07:00
Chris Jones
ea37043577 Switch to STATUS_RETURNING callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
Chris Jones
2155b9181f Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00