74 Commits

Dan Foreman-Mackey
9ae1c56c44 Update lu_pivots_to_permutation to use FFI dimensions on GPU.
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend instead of being passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.

PiperOrigin-RevId: 647343656
2024-06-27 09:27:15 -07:00
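As a rough sketch of the pattern described above (not the actual jaxlib handler; the handler name is illustrative and stream/context plumbing is omitted), an XLA FFI kernel can read the batch size straight from the buffer metadata it receives:

```cpp
#include <cstddef>
#include <cstdint>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Illustrative handler only: the pivot buffer has shape (batch..., k), so the
// batch size and pivot count are read from the buffer's dimensions rather
// than from serialized attributes.
static ffi::Error LuPivotsToPermutationSketch(
    ffi::Buffer<ffi::DataType::S32> pivots,
    ffi::ResultBuffer<ffi::DataType::S32> permutation) {
  auto dims = pivots.dimensions();
  std::int64_t batch = 1;
  for (std::size_t i = 0; i + 1 < dims.size(); ++i) batch *= dims[i];
  std::int64_t k = dims.size() > 0 ? dims[dims.size() - 1] : 0;
  // ... launch the GPU kernel over `batch` permutations of length `k` ...
  (void)batch;
  (void)k;
  (void)permutation;
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    kLuPivotsToPermutationSketch, LuPivotsToPermutationSketch,
    ffi::Ffi::Bind()
        .Arg<ffi::Buffer<ffi::DataType::S32>>()    // pivots
        .Ret<ffi::Buffer<ffi::DataType::S32>>());  // permutation
```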
Ruturaj4
79fccf6c82 add cholesky changes in bazel 2024-05-18 00:37:09 +00:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergei Lebedev
51fc4f85ad Ported LuPivotsToPermutation to the typed XLA FFI
The typed FFI

* allows passing custom call attributes directly to backend_config= instead
  of serializing them into a C++ struct.
* handles validation and deserialization of custom call operands.

PiperOrigin-RevId: 630067005
2024-05-02 08:12:05 -07:00
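For contrast, a hedged sketch of the attribute-binding style this port enables (the attribute name is illustrative, not the exact jaxlib binding): scalar attributes arrive through backend_config and the binding validates and decodes them, so no hand-rolled opaque struct is required.

```cpp
#include <cstdint>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Illustrative only: `permutation_size` stands in for whatever scalar
// attributes the call needs; the binding below decodes it from the custom
// call's backend_config.
static ffi::Error TypedFfiAttrSketch(
    std::int32_t permutation_size,
    ffi::Buffer<ffi::DataType::S32> pivots,
    ffi::ResultBuffer<ffi::DataType::S32> permutation) {
  // ... launch the kernel using `permutation_size` ...
  (void)permutation_size;
  (void)pivots;
  (void)permutation;
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    kTypedFfiAttrSketch, TypedFfiAttrSketch,
    ffi::Ffi::Bind()
        .Attr<std::int32_t>("permutation_size")
        .Arg<ffi::Buffer<ffi::DataType::S32>>()    // pivots
        .Ret<ffi::Buffer<ffi::DataType::S32>>());  // permutation
```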
Marvin Kim
90e9e47a55 [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
Configs that cannot be launched should be skipped rather than benchmarked.

PiperOrigin-RevId: 626153377
2024-04-18 14:35:51 -07:00
Jieying Luo
44e83d4e0a Add a few custom call registrations to gpu_kernel to keep in-sync with callers of xla_client.register_custom_call_target.
PiperOrigin-RevId: 624275186
2024-04-12 13:30:18 -07:00
Henning Becker
9809aa1929 Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target

The moved code is unchanged; I just moved it from one
file to another and fixed up the includes and dependencies.

Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.

PiperOrigin-RevId: 623285804
2024-04-09 14:43:41 -07:00
Marvin Kim
722708052c [JAX] Fix typo in comment.
PiperOrigin-RevId: 621827985
2024-04-04 05:35:28 -07:00
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Rahul Batra
8575055571 [ROCm]: Add missing hipStreamWaitEvent API call 2024-03-20 16:58:21 +00:00
Peter Hawkins
c2bbf9c577 Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
2024-03-18 11:25:03 -07:00
Andrey Portnoy
dcb58bb540 Include <cstdint> in files where it is used 2024-03-06 11:58:15 -05:00
jax authors
7514d5c7aa [triton] Add clustering support and test
PiperOrigin-RevId: 612417957
2024-03-04 05:51:10 -08:00
Eugene Zhulenev
1ae2022918 [jax-triton] Do not capture jax-triton calls that require autotuning
PiperOrigin-RevId: 611823473
2024-03-01 10:28:47 -08:00
Eugene Zhulenev
3a69b80774 [jax-triton] Synchronize autotuning stream with a main one
PiperOrigin-RevId: 609792049
2024-02-23 11:42:30 -08:00
Chris Jones
fcc8b54789 [jax_triton] Use ReaderLock on fast path to reduce lock contention in multi-GPU settings.
PiperOrigin-RevId: 606648981
2024-02-13 09:31:50 -08:00
Anlun Xu
d62071066e [jax:triton] Add a workaround for calling cuStreamGetCtx inside graph capture
A bug in CUDA prevents us from calling gpuStreamGetCtx inside graph capture. We use cuCtxGetCurrent as a workaround for now.

PiperOrigin-RevId: 605417225
2024-02-08 13:49:45 -08:00
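A minimal sketch of the shape such a workaround can take (not the jaxlib implementation; error handling for the capture-status query is elided):

```cpp
#include <cuda.h>

// Returns the CUcontext to use for later driver calls. cuStreamGetCtx cannot
// be called while the stream is being captured into a CUDA graph, so fall
// back to the thread's current context in that case.
static CUresult GetContextForStream(CUstream stream, CUcontext* context) {
  CUstreamCaptureStatus status = CU_STREAM_CAPTURE_STATUS_NONE;
  cuStreamIsCapturing(stream, &status);
  if (status == CU_STREAM_CAPTURE_STATUS_ACTIVE) {
    // Workaround: assume the thread's current context is the stream's context.
    return cuCtxGetCurrent(context);
  }
  return cuStreamGetCtx(stream, context);
}
```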
Rahul Batra
f01c27f65a [ROCm]: Add ROCm command buffer support for triton kernel 2024-02-05 19:34:12 +00:00
Anlun Xu
16636f9c97 [jax_triton] Only use side stream to do autotuning when doing graph capture
When graph capture is not enabled, autotuning and kernel launch should be on the same stream to avoid a race condition.

PiperOrigin-RevId: 603728867
2024-02-02 10:48:26 -08:00
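A hedged sketch of the stream-selection logic described above (the helper and the pre-created side stream are assumptions, not the actual jax_triton code):

```cpp
#include <cuda.h>

// Picks the stream on which to run autotuning benchmarks. `side_stream` is a
// hypothetical pre-created stream, used only while `main_stream` is being
// captured into a CUDA graph; outside of capture, benchmarking on the main
// stream avoids racing with the real kernel launch.
static CUresult PickAutotuningStream(CUstream main_stream, CUstream side_stream,
                                     CUstream* autotuning_stream) {
  CUstreamCaptureStatus status = CU_STREAM_CAPTURE_STATUS_NONE;
  CUresult result = cuStreamIsCapturing(main_stream, &status);
  if (result != CUDA_SUCCESS) return result;
  *autotuning_stream = (status == CU_STREAM_CAPTURE_STATUS_ACTIVE)
                           ? side_stream   // keep benchmarks out of the graph
                           : main_stream;  // no capture: same stream, no race
  return CUDA_SUCCESS;
}
```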
Anlun Xu
5e009f9ff1 Make triton kernels compatible with command buffers
Autotuning is not compatible with graph capture because it requires synchronizing.

We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)

PiperOrigin-RevId: 602436960
2024-01-29 11:00:29 -08:00
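A minimal sketch of the capture-mode trick, mirroring the NCCL pattern referenced above (not the jaxlib implementation; the benchmark helper in the usage comment is hypothetical):

```cpp
#include <cuda.h>

// RAII guard that switches the calling thread to relaxed stream-capture mode,
// so that work it issues on non-capturing streams is permitted even while a
// graph capture is in progress elsewhere, then restores the previous mode.
class RelaxedCaptureModeScope {
 public:
  RelaxedCaptureModeScope() : mode_(CU_STREAM_CAPTURE_MODE_RELAXED) {
    // Exchanges the thread's capture mode; the previous mode is left in mode_.
    cuThreadExchangeStreamCaptureMode(&mode_);
  }
  ~RelaxedCaptureModeScope() {
    // Restore whatever mode was active before this scope.
    cuThreadExchangeStreamCaptureMode(&mode_);
  }

 private:
  CUstreamCaptureMode mode_;
};

// Usage sketch (RunAutotuningBenchmarks is hypothetical):
//   {
//     RelaxedCaptureModeScope scope;
//     RunAutotuningBenchmarks();  // not recorded into the captured graph
//   }
```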
Anlun Xu
88f5eaca3e [xla:gpu] Make cu_threefry2x32 custom call compatible with command buffers
PiperOrigin-RevId: 600937786
2024-01-23 16:14:21 -08:00
jax authors
5761e393fa The Triton autotuner ignores configs that use too much shmem
The autotuner runs a series of benchmarks to determine the best configuration
for a Triton kernel. However, if it encounters a config that does not fit in
shared memory, it throws an error and stops. In that case it should just skip that
config and continue.

PiperOrigin-RevId: 600730507
2024-01-23 03:08:57 -08:00
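One plausible shape of such a fix (the config struct and helper are hypothetical, not the actual autotuner): query the device's shared-memory limit once and filter out candidates that exceed it instead of letting the first oversized one abort the run.

```cpp
#include <cstdint>
#include <vector>

#include <cuda.h>

// Hypothetical description of one autotuning candidate.
struct KernelConfig {
  std::int32_t shared_mem_bytes;
  // ... block sizes, num_warps, num_stages, etc. ...
};

// Returns only the configs that fit in shared memory, so that one oversized
// candidate no longer aborts the whole autotuning run.
static std::vector<KernelConfig> LaunchableConfigs(
    CUdevice device, const std::vector<KernelConfig>& configs) {
  int max_shared_mem = 0;
  cuDeviceGetAttribute(&max_shared_mem,
                       CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
                       device);
  std::vector<KernelConfig> launchable;
  for (const KernelConfig& config : configs) {
    if (config.shared_mem_bytes <= max_shared_mem) {
      launchable.push_back(config);  // benchmark this one
    }
    // else: skip it and keep going instead of raising an error
  }
  return launchable;
}
```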
Rahul Batra
f997609e76 [ROCm]: Updates hip headers path for ROCm 6.0 2024-01-22 16:08:37 +00:00
jax authors
ab3c1b5146 [triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
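A rough sketch of the two launch paths these Triton changes describe (grid, block, and cluster shapes simplified to 1D; error handling elided; not the actual TritonKernel code):

```cpp
#include <cuda.h>

// Launches `fn` through cuLaunchKernelEx with an explicit cluster dimension,
// or through plain cuLaunchKernel when no real clustering is requested
// (cluster size <= 1).
static CUresult LaunchMaybeClustered(CUfunction fn, unsigned grid_x,
                                     unsigned block_x,
                                     unsigned shared_mem_bytes,
                                     CUstream stream, void** params,
                                     unsigned cluster_x) {
  if (cluster_x <= 1) {
    return cuLaunchKernel(fn, grid_x, 1, 1, block_x, 1, 1, shared_mem_bytes,
                          stream, params, /*extra=*/nullptr);
  }
  CUlaunchAttribute attr = {};
  attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  attr.value.clusterDim.x = cluster_x;
  attr.value.clusterDim.y = 1;
  attr.value.clusterDim.z = 1;

  CUlaunchConfig config = {};
  config.gridDimX = grid_x;
  config.gridDimY = 1;
  config.gridDimZ = 1;
  config.blockDimX = block_x;
  config.blockDimY = 1;
  config.blockDimZ = 1;
  config.sharedMemBytes = shared_mem_bytes;
  config.hStream = stream;
  config.attrs = &attr;
  config.numAttrs = 1;
  return cuLaunchKernelEx(&config, fn, params, /*extra=*/nullptr);
}
```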
jax authors
59ea9f3fde [triton] Use cuLaunchKernelEx instead of cuLaunchKernel
PiperOrigin-RevId: 597555083
2024-01-11 07:52:07 -08:00
Peter Hawkins
95e2d3fc2b [JAX:GPU] Generalize gesvdj kernel to iterate over the unbatched Jacobi kernel in cases where we cannot use the batched kernel.
If gesvdj() is preferable to gesvd() in the absence of a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() calls over a loop of gesvd() calls.

PiperOrigin-RevId: 582279549
2023-11-14 04:52:15 -08:00
jax authors
88fe0da6d1 Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton
PiperOrigin-RevId: 574546618
2023-10-18 11:56:01 -07:00
Rahul Batra
b4b97cd8e8 [ROCm]: Add jax-triton support for ROCm 2023-10-18 07:09:20 +00:00
Peter Hawkins
07fa9dc3db Fix cupti-related build failure under CUDA 11.
cuptiGetErrorMessage was added in CUDA 12.2.

PiperOrigin-RevId: 568962562
2023-09-27 14:33:30 -07:00
Peter Hawkins
9404518201 [CUDA] Add code to JAX initialization that verifies that the CUDA libraries found at runtime are at least as new as the versions against which JAX was built.
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.

PiperOrigin-RevId: 568910422
2023-09-27 11:28:40 -07:00
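A hedged sketch of the kind of check described above, using only the CUDA runtime as an example (the real code covers several libraries and reports errors differently):

```cpp
#include <stdexcept>
#include <string>

#include <cuda_runtime_api.h>  // cudaRuntimeGetVersion, CUDART_VERSION

// Throws if the CUDA runtime found at load time is older than the one this
// code was compiled against. The real check would repeat this for cuSOLVER,
// cuSPARSE, cuDNN, and so on; this is a single-library illustration.
static void CheckCudaRuntimeVersion() {
  int runtime_version = 0;
  if (cudaRuntimeGetVersion(&runtime_version) != cudaSuccess) {
    throw std::runtime_error("Unable to query the CUDA runtime version.");
  }
  if (runtime_version < CUDART_VERSION) {
    throw std::runtime_error(
        "Loaded CUDA runtime " + std::to_string(runtime_version) +
        " is older than the build-time version " +
        std::to_string(CUDART_VERSION) +
        ". Check for a stale self-installed CUDA or an LD_LIBRARY_PATH "
        "override shadowing the pip-installed libraries.");
  }
}
```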
Andrey Portnoy
fc1c31d958 Run LSTM test using FP32 math (as opposed to TF32)
1. Add (limited) precision specifier handling to LSTM

Enables differentiating between TF32 and FP32 math. TF32 math had insufficient
precision to reliably pass LSTM correctness tests on A100 and H100.

2. Run the test using FP32

TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs
such as A100 and H100.
2023-09-19 14:45:14 -04:00
Chris Jones
9f7a19ad50 [jax_triton] Improve error message when shared memory exceeds that available on the GPU.
PiperOrigin-RevId: 561792406
2023-08-31 16:32:18 -07:00
Peter Hawkins
46ac9e2170 Use the default CSR matmul algorithm.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cuSPARSE SpMM CSR matmuls. In the past, cuSPARSE silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cuSPARSE 12.2.1 removed the silent fallback behavior. Since we are not actually getting deterministic behavior in all cases anyway, always use the default algorithm.

PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
Peter Hawkins
34010a9e4a Align dummy pointers passed to cusparse to 16 bytes
Fixes alignment errors from Cusparse 12.2.

PiperOrigin-RevId: 560793586
2023-08-28 12:56:27 -07:00
Peter Hawkins
ac8ea86103 Fix accidental signature change to get_serialized_metadata() from nanobind PR.
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.

PiperOrigin-RevId: 560086072
2023-08-25 07:31:31 -07:00
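A small sketch of the distinction (module and function names are illustrative): with nanobind, a parameter that should accept Python bytes is declared as nb::bytes rather than relying on an implicit conversion to std::string.

```cpp
#include <string>

#include <nanobind/nanobind.h>

namespace nb = nanobind;

// pybind11 converts either `str` or `bytes` into a std::string argument;
// nanobind does not, so a parameter carrying opaque binary data is declared
// as nb::bytes explicitly.
NB_MODULE(example_ext, m) {
  m.def("get_serialized_metadata_demo", [](nb::bytes opaque) -> nb::bytes {
    std::string data(opaque.c_str(), opaque.size());
    // ... parse `data` and build the metadata blob to return ...
    return nb::bytes(data.data(), data.size());
  });
}
```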
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Chris Jones
4ac2bdc2b1 [jax_triton] Add user-specified name field to serialized format.
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Srinivas Vasudevan
7dfc8ff49d Add batching rules to jax.lax.linalg.tridiagonal_solve.
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
Chris Jones
714156df63 [jax_triton] Add support for float scalar inputs.
Python `float`s are inferred as "f64". Values can be passed as "f32" using `np.float32(value)`.

PiperOrigin-RevId: 552036612
2023-07-28 22:57:08 -07:00
Chris Jones
9935445d57 [jax_triton] Simplify auto-tuning code.
PiperOrigin-RevId: 545733541
2023-07-05 11:18:18 -07:00
Chris Jones
31b862dd56 [jax_triton] Split C++ only parts of Triton custom callback from Python parts.
Register the callback with the default call target name from C++, enabling Triton calls with the default name to work in C++-only contexts (e.g. serving).

PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
3f9da19c63 Add get_serialized_metadata function to retrieve metadata from op's opaque data.
PiperOrigin-RevId: 544608895
2023-06-30 03:23:28 -07:00
Chris Jones
d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
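A hedged sketch of what registering such a callback from C++ can look like (the callback, its simplified signature, and the target name are hypothetical stand-ins for what the jax_triton header exposes):

```cpp
#include <cstddef>

#include "xla/service/custom_call_target_registry.h"

// Hypothetical stand-in for the callback the jax_triton header exposes; the
// simplified signature follows the style of XLA's legacy GPU custom calls.
extern "C" void MyTritonKernelCall(void* stream, void** buffers,
                                   const char* opaque, std::size_t opaque_len) {
  // Real code would decode `opaque` and launch the Triton kernel here.
  (void)stream;
  (void)buffers;
  (void)opaque;
  (void)opaque_len;
}

// Register the callback under a non-default call target name so that custom
// calls emitted with that name resolve to it on the CUDA platform.
static const bool kRegistered = [] {
  xla::CustomCallTargetRegistry::Global()->Register(
      "my_triton_kernel_call", reinterpret_cast<void*>(&MyTritonKernelCall),
      "CUDA");
  return true;
}();
```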
Chris Jones
b3527f3975 Zlib compress kernel proto.
PiperOrigin-RevId: 542529065
2023-06-22 05:22:53 -07:00
Chris Jones
f238667492 Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Chris Jones
64e73270ff Use EncapsulateFunction utility.
PiperOrigin-RevId: 542299099
2023-06-21 10:37:52 -07:00
Peter Zhizhin
01ed663163 Add a comment at the end of autotuning a Triton function
I'd like to see which config was chosen at the end of autotuning. Tracking the `is the new best config` messages is a bit hard.

PiperOrigin-RevId: 538465212
2023-06-07 06:10:28 -07:00
jax authors
3ba308d4fc [triton] Raise exception on too much shared memory requested
PiperOrigin-RevId: 538135346
2023-06-06 03:55:53 -07:00
Chris Jones
ea37043577 Switch to STATUS_RETURNING callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
Chris Jones
2155b9181f Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00