66 Commits

Author SHA1 Message Date
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Rahul Batra
8575055571 [ROCm]: Add missing hipStreamWaitEvent API call 2024-03-20 16:58:21 +00:00
Peter Hawkins
c2bbf9c577 Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
2024-03-18 11:25:03 -07:00
Andrey Portnoy
dcb58bb540 Include <cstdint> in files where it is used 2024-03-06 11:58:15 -05:00
jax authors
7514d5c7aa [triton] Add clustering support and test
PiperOrigin-RevId: 612417957
2024-03-04 05:51:10 -08:00
Eugene Zhulenev
1ae2022918 [jax-triton] Do not capture jax-triton calls that require autotuning
PiperOrigin-RevId: 611823473
2024-03-01 10:28:47 -08:00
Eugene Zhulenev
3a69b80774 [jax-triton] Synchronize autotuning stream with a main one
PiperOrigin-RevId: 609792049
2024-02-23 11:42:30 -08:00
Chris Jones
fcc8b54789 [jax_triton] Use ReaderLock on fast path to reduce lock contention in multi-GPU settings.
PiperOrigin-RevId: 606648981
2024-02-13 09:31:50 -08:00
Anlun Xu
d62071066e [jax:triton] Add a workaround for calling cuStreamGetCtx inside graph capture
A bug in CUDA prevents us from calling gpuStreamGetCtx inside graph capture. We use cuCtxGetCurrent as workaround for now.

PiperOrigin-RevId: 605417225
2024-02-08 13:49:45 -08:00
Rahul Batra
f01c27f65a [ROCm]: Add ROCm command buffer support for triton kernel 2024-02-05 19:34:12 +00:00
Anlun Xu
16636f9c97 [jax_triton] Only use side stream to do autotuning when doing graph capture
When graph capture is not enabled, autotuning and kernel launch should be on the same stream to avoid race condition.

PiperOrigin-RevId: 603728867
2024-02-02 10:48:26 -08:00
Anlun Xu
5e009f9ff1 Make triton kernels compatible with command buffers
Autotuning is not compatible with graph capture because it requires synchronizing.

We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)

PiperOrigin-RevId: 602436960
2024-01-29 11:00:29 -08:00
Anlun Xu
88f5eaca3e [xla:gpu] Make cu_threefry2x32 custom call compatible with command buffers
PiperOrigin-RevId: 600937786
2024-01-23 16:14:21 -08:00
jax authors
5761e393fa The Triton autotuner ignores configs that use too much shmem
The autotuner runs a series of benchmarks to determine the best configuration
for a Triton kernel. However, if it encounters a config that does not fit in
shared memory it throws an error and stops. I this eventuality it should just
continue.

PiperOrigin-RevId: 600730507
2024-01-23 03:08:57 -08:00
Rahul Batra
f997609e76 [ROCm]: Updates hip headers path for ROCm 6.0 2024-01-22 16:08:37 +00:00
jax authors
ab3c1b5146 [triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
jax authors
59ea9f3fde [triton] Use cuLaunchKernelEx instead of cuLaunchKernel
PiperOrigin-RevId: 597555083
2024-01-11 07:52:07 -08:00
Peter Hawkins
95e2d3fc2b [JAX:GPU] Generalize gesvdj kernel to iterate over the unbatched Jacobi kernel in cases that we cannot use the batched kernel.
If the gesvdj() is preferable to gesvd() absent a batch dimension, even if there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().

PiperOrigin-RevId: 582279549
2023-11-14 04:52:15 -08:00
jax authors
88fe0da6d1 Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton
PiperOrigin-RevId: 574546618
2023-10-18 11:56:01 -07:00
Rahul Batra
b4b97cd8e8 [ROCm]: Add jax-triton support for ROCm 2023-10-18 07:09:20 +00:00
Peter Hawkins
07fa9dc3db Fix cupti-related build failure under CUDA 11.
cuptiGetErrorMessage was added in CUDA 12.2.

PiperOrigin-RevId: 568962562
2023-09-27 14:33:30 -07:00
Peter Hawkins
9404518201 [CUDA] Add code to jax initialization that verifies that the CUDA libraries that are found are at least as new as the versions against which JAX was built.
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.

PiperOrigin-RevId: 568910422
2023-09-27 11:28:40 -07:00
Andrey Portnoy
fc1c31d958 Run LSTM test using FP32 math (as opposed to TF32)
1. Add (limited) precision specifier handling to LSTM

Enables differentiating between TF32 and FP32 math. TF32 math had insufficient
precision to reliably pass LSTM correctness tests on A100 and H100.

2. Run the test using FP32

TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs
such as A100 and H100.
2023-09-19 14:45:14 -04:00
Chris Jones
9f7a19ad50 [jax_triton] Improve error message when shared memory exceeds that available on the GPU.
PiperOrigin-RevId: 561792406
2023-08-31 16:32:18 -07:00
Peter Hawkins
46ac9e2170 Use the default CSR matmul algorithm.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.

PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
Peter Hawkins
34010a9e4a Align dummy pointers passed to cusparse to 16 bytes
Fixes alignment errors from Cusparse 12.2.

PiperOrigin-RevId: 560793586
2023-08-28 12:56:27 -07:00
Peter Hawkins
ac8ea86103 Fix accidental signature change to get_serialized_metadata() from nanobind PR.
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.

PiperOrigin-RevId: 560086072
2023-08-25 07:31:31 -07:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Chris Jones
4ac2bdc2b1 [jax_triton] Add user-specified name field to serialized format.
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Srinivas Vasudevan
7dfc8ff49d Add batching rules to jax.lax.linalg.tridiagonal_solve.
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
Chris Jones
714156df63 [jax_triton] Add support for float scalar inputs.
Python `float`s are inferred as "f64". Values can be passed as "f32" using `np.float32(value)`.

PiperOrigin-RevId: 552036612
2023-07-28 22:57:08 -07:00
Chris Jones
9935445d57 [jax_triton] Simplify auto-tuning code.
PiperOrigin-RevId: 545733541
2023-07-05 11:18:18 -07:00
Chris Jones
31b862dd56 [jax_triton] Split C++ only parts of Triton custom callback from Python parts.
Register callback with default call target name from C++, enabling Triton calls with the default name to work in C++ only contexts (e.g. serving).

PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
3f9da19c63 Add get_serialized_metadata function to retrieve metadata from op's opaque data.
PiperOrigin-RevId: 544608895
2023-06-30 03:23:28 -07:00
Chris Jones
d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
Chris Jones
b3527f3975 Zlib compress kernel proto.
PiperOrigin-RevId: 542529065
2023-06-22 05:22:53 -07:00
Chris Jones
f238667492 Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Chris Jones
64e73270ff Use EncapsulateFunction utility.
PiperOrigin-RevId: 542299099
2023-06-21 10:37:52 -07:00
Peter Zhizhin
01ed663163 Add a comment at the end of autotuning a Triton function
I'd like to see what config was chosen at the end of autotuning. Tracking `is the new best config` is a bit hard.

PiperOrigin-RevId: 538465212
2023-06-07 06:10:28 -07:00
jax authors
3ba308d4fc [triton] Raise exception on too much shared memory requested
PiperOrigin-RevId: 538135346
2023-06-06 03:55:53 -07:00
Chris Jones
ea37043577 Switch to STATUS_RETURNING callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
Chris Jones
2155b9181f Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Chris Jones
6b13d4eb86 Add branch prediction to JAX status macros.
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
a89c377762 [GPU] Fix another instance of missing stream synchronization in RNN kernels.
PiperOrigin-RevId: 530660502
2023-05-09 11:08:24 -07:00
Peter Hawkins
f168a1560c [GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.
May fix flaky failures in CI.

Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it.

PiperOrigin-RevId: 530425766
2023-05-08 15:37:04 -07:00
Peter Hawkins
6b9a109939 Use stream-synchronized copy in rnn_kernels.cc.
May fix flaky wrong outputs sometimes seen in CI.

Also check for errors in another use of gpuStreamSynchronize().

PiperOrigin-RevId: 530391917
2023-05-08 13:28:08 -07:00
George Necula
a2ac510dc3 [shape_poly] Add support for dynamic shapes for eigh
We can only handle dynamic sizes for the batch dimensions for now.

PiperOrigin-RevId: 529001830
2023-05-02 23:27:59 -07:00
Matthew Johnson
56feaca7f9 update cuDNN RNN code not to save 'workspace' scratch between fwd and bwd
PiperOrigin-RevId: 528928263
2023-05-02 17:05:42 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00