Rahul Batra
4b7c198a1c
[ROCm]: Add get_arch_details for triton kernel call
2024-08-12 06:16:27 +00:00
jax authors
ab3c1b5146
[triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
Rahul Batra
b4b97cd8e8
[ROCm]: Add jax-triton support for ROCm
2023-10-18 07:09:20 +00:00
Peter Hawkins
ac8ea86103
Fix accidental signature change to get_serialized_metadata() from nanobind PR.
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.
PiperOrigin-RevId: 560086072
2023-08-25 07:31:31 -07:00
Peter Hawkins
70b7d50181
Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is that nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python-version CUDA plugins starting with Python 3.12.
PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Chris Jones
4ac2bdc2b1
[jax_triton] Add user-specified `name` field to serialized format.
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Chris Jones
714156df63
[jax_triton] Add support for float scalar inputs.
Python `float`s are inferred as "f64". Values can be passed as "f32" using `np.float32(value)`.
PiperOrigin-RevId: 552036612
2023-07-28 22:57:08 -07:00
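The inference rule described in the float scalar commit above can be illustrated with a minimal sketch. The helper name `infer_float_tag` and the lookup table are hypothetical and are not taken from the jax_triton source; only the rule itself (Python `float` is treated as "f64", `np.float32(value)` as "f32") comes from the commit message.

```python
import numpy as np

# Hypothetical mapping from NumPy float dtypes to short type tags.
_NUMPY_FLOAT_TAGS = {
    np.dtype(np.float32): "f32",
    np.dtype(np.float64): "f64",
}


def infer_float_tag(value):
    """Hypothetical helper: return the type tag for a float scalar argument."""
    # Check NumPy scalars first: np.float64 is also a subclass of Python float,
    # so the order of these checks matters.
    if isinstance(value, np.floating):
        tag = _NUMPY_FLOAT_TAGS.get(value.dtype)
        if tag is None:
            raise TypeError(f"unsupported float dtype: {value.dtype}")
        return tag
    # A plain Python float carries no width information, so it is
    # treated as a 64-bit value.
    if isinstance(value, float):
        return "f64"
    raise TypeError(f"not a float scalar: {type(value).__name__}")


print(infer_float_tag(0.5))              # plain Python float
print(infer_float_tag(np.float32(0.5)))  # explicitly narrowed to 32 bits
```

Wrapping the value in `np.float32` is how a caller opts into the narrower type; without it, the wider type is assumed.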
Chris Jones
31b862dd56
[jax_triton] Split C++ only parts of Triton custom callback from Python parts.
Register callback with default call target name from C++, enabling Triton calls with the default name to work in C++ only contexts (e.g. serving).
PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
3f9da19c63
Add `get_serialized_metadata` function to retrieve metadata from op's opaque data.
PiperOrigin-RevId: 544608895
2023-06-30 03:23:28 -07:00
Chris Jones
d4e2464340
[jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.
PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
Chris Jones
b3527f3975
Zlib compress kernel proto.
PiperOrigin-RevId: 542529065
2023-06-22 05:22:53 -07:00
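The commit above gives no detail beyond its title, so the following is only a generic sketch of zlib-compressing a serialized payload and restoring it. The function names and the placeholder bytes are assumptions for illustration; they do not reflect the actual kernel proto or the code in that change.

```python
import zlib


def compress_payload(serialized: bytes) -> bytes:
    """Compress already-serialized bytes with zlib (default compression level)."""
    return zlib.compress(serialized)


def decompress_payload(compressed: bytes) -> bytes:
    """Invert compress_payload and return the original serialized bytes."""
    return zlib.decompress(compressed)


# Placeholder bytes standing in for a serialized message; not a real kernel proto.
payload = b"kernel-proto-placeholder " * 64
compressed = compress_payload(payload)

# The round trip must be lossless for the payload to be usable afterwards.
assert decompress_payload(compressed) == payload
print(len(payload), len(compressed))
```

Repetitive payloads such as the placeholder compress well; how much a real serialized kernel shrinks depends on its contents.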
Chris Jones
f238667492
Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Chris Jones
64e73270ff
Use `EncapsulateFunction` utility.
PiperOrigin-RevId: 542299099
2023-06-21 10:37:52 -07:00
Peter Zhizhin
01ed663163
Add a comment at the end of autotuning a Triton function
I'd like to see what config was chosen at the end of autotuning. Tracking `is the new best config` is a bit hard.
PiperOrigin-RevId: 538465212
2023-06-07 06:10:28 -07:00
jax authors
3ba308d4fc
[triton] Raise exception on too much shared memory requested
PiperOrigin-RevId: 538135346
2023-06-06 03:55:53 -07:00
Chris Jones
ea37043577
Switch to `STATUS_RETURNING` callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
Chris Jones
2155b9181f
Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Sharad Vikram
bf8ed6a543
Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00