25 Commits

Author SHA1 Message Date
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCm build.
Change in preparation for adding more presubmits for AMD ROCm.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00
Dan Foreman-Mackey
bd90968a25 Port the GPU Cholesky update custom call to the FFI.
PiperOrigin-RevId: 665319689
2024-08-20 05:46:03 -07:00
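
A minimal sketch of what such an FFI port looks like, using the documented XLA FFI C++ API; the handler name, buffer layout, and update semantics here are illustrative, not the actual jaxlib binding:

    #include <cuda_runtime.h>
    #include "xla/ffi/api/ffi.h"

    namespace ffi = xla::ffi;

    // Hypothetical handler: consumes the factor R and update vector w, and
    // writes the updated factor, launching work on the XLA-provided stream.
    static ffi::Error CholUpdateImpl(cudaStream_t stream,
                                     ffi::Buffer<ffi::F32> r,
                                     ffi::Buffer<ffi::F32> w,
                                     ffi::ResultBuffer<ffi::F32> r_out) {
      // ... launch the Cholesky-update kernel on `stream` ...
      return ffi::Error::Success();
    }

    XLA_FFI_DEFINE_HANDLER_SYMBOL(
        CholUpdate, CholUpdateImpl,
        ffi::Ffi::Bind()
            .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream comes from XLA
            .Arg<ffi::Buffer<ffi::F32>>()              // R
            .Arg<ffi::Buffer<ffi::F32>>()              // w
            .Ret<ffi::Buffer<ffi::F32>>());            // updated R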
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
Rahul Batra
8575055571 [ROCm]: Add missing hipStreamWaitEvent API call 2024-03-20 16:58:21 +00:00
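
A short sketch of the cross-stream ordering this API call provides, assuming producer and consumer streams already exist (illustrative, not the patched jaxlib code):

    #include <hip/hip_runtime.h>

    // Make `consumer` wait for all work submitted to `producer` so far,
    // without blocking the host thread.
    void OrderStreams(hipStream_t producer, hipStream_t consumer) {
      hipEvent_t done;
      hipEventCreateWithFlags(&done, hipEventDisableTiming);
      hipEventRecord(done, producer);         // fence point in the producer
      hipStreamWaitEvent(consumer, done, 0);  // consumer blocks on the event
      hipEventDestroy(done);                  // safe: destruction is deferred
    }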
Peter Hawkins
c2bbf9c577 Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
2024-03-18 11:25:03 -07:00
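
Illustrative example of the kind of guard such a cleanup deletes: once the support floor is CUSPARSE 11.7.5, the legacy branch is dead code (version spelled per the CUSPARSE_VERSION macro convention):

    #include <cusparse.h>

    #if CUSPARSE_VERSION >= 11705  // 11.7.5, the minimum after this change
      // Modern code path -- after the cleanup, the only path.
    #else
      // Fallback for older CUSPARSE -- removable once the floor is raised.
    #endif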
Eugene Zhulenev
3a69b80774 [jax-triton] Synchronize autotuning stream with a main one
PiperOrigin-RevId: 609792049
2024-02-23 11:42:30 -08:00
Anlun Xu
d62071066e [jax:triton] Add a workaround for calling cuStreamGetCtx inside graph capture
A bug in CUDA prevents us from calling gpuStreamGetCtx inside graph capture. We use cuCtxGetCurrent as a workaround for now.

PiperOrigin-RevId: 605417225
2024-02-08 13:49:45 -08:00
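
A sketch of the workaround described above, written as a hypothetical helper (the real call sites live in the jax-triton kernel-call code):

    #include <cuda.h>

    // During graph capture, cuStreamGetCtx hits a CUDA bug, so fall back to
    // the context that is current on this thread.
    CUresult GetContextForStream(CUstream stream, bool capturing,
                                 CUcontext* ctx) {
      if (capturing) return cuCtxGetCurrent(ctx);
      return cuStreamGetCtx(stream, ctx);
    }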
Rahul Batra
f01c27f65a [ROCm]: Add ROCm command buffer support for triton kernels 2024-02-05 19:34:12 +00:00
Anlun Xu
16636f9c97 [jax_triton] Only use side stream to do autotuning when doing graph capture
When graph capture is not enabled, autotuning and the kernel launch should be on the same stream to avoid a race condition.

PiperOrigin-RevId: 603728867
2024-02-02 10:48:26 -08:00
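
A sketch of that dispatch, with a hypothetical RunAutotuningPasses standing in for the real autotuner:

    #include <cuda_runtime.h>

    void RunAutotuningPasses(cudaStream_t stream);  // hypothetical

    void Autotune(cudaStream_t main_stream, cudaStream_t side_stream) {
      cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
      cudaStreamIsCapturing(main_stream, &status);
      if (status == cudaStreamCaptureStatusActive) {
        // Keep the capture clean: autotune on the side stream.
        RunAutotuningPasses(side_stream);
        cudaStreamSynchronize(side_stream);
      } else {
        // Same stream as the launch: no race between autotuning and the kernel.
        RunAutotuningPasses(main_stream);
      }
    }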
Anlun Xu
5e009f9ff1 Make triton kernels compatible with command buffers
Autotuning is not compatible with graph capture because it requires synchronization.

We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)

PiperOrigin-RevId: 602436960
2024-01-29 11:00:29 -08:00
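
A sketch of the capture-mode exchange, mirroring the NCCL pattern cited above:

    #include <cuda.h>

    void RunUncaptured() {
      CUstreamCaptureMode mode = CU_STREAM_CAPTURE_MODE_RELAXED;
      // Swap relaxed mode in; the thread's previous mode is returned in `mode`.
      cuThreadExchangeStreamCaptureMode(&mode);
      // ... synchronizing autotuning work that must not be recorded ...
      // Swap the previous mode back.
      cuThreadExchangeStreamCaptureMode(&mode);
    }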
Rahul Batra
f997609e76 [ROCm]: Update HIP header paths for ROCm 6.0 2024-01-22 16:08:37 +00:00
Peter Hawkins
95e2d3fc2b [JAX:GPU] Generalize gesvdj kernel to iterate over the unbatched Jacobi kernel in cases where we cannot use the batched kernel.
If gesvdj() is preferable to gesvd() in the absence of a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() calls over a loop of gesvd() calls.

PiperOrigin-RevId: 582279549
2023-11-14 04:52:15 -08:00
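
Illustrative dispatch for that policy: cusolver's batched Jacobi kernel only supports matrices up to 32x32, so outside that range we loop the unbatched gesvdj over the batch (assumes m >= n and column-major buffers; workspace setup and error handling omitted):

    #include <cusolverDn.h>

    void BatchedSvd(cusolverDnHandle_t handle, gesvdjInfo_t params, int batch,
                    int m, int n, float* A, float* S, float* U, float* V,
                    float* work, int lwork, int* info) {
      if (m <= 32 && n <= 32) {
        cusolverDnSgesvdjBatched(handle, CUSOLVER_EIG_MODE_VECTOR, m, n, A, m,
                                 S, U, m, V, n, work, lwork, info, params,
                                 batch);
      } else {
        // A loop of gesvdj() calls, not a loop of gesvd() calls.
        for (int i = 0; i < batch; ++i) {
          cusolverDnSgesvdj(handle, CUSOLVER_EIG_MODE_VECTOR, /*econ=*/0, m, n,
                            A + i * m * n, m, S + i * n, U + i * m * m, m,
                            V + i * n * n, n, work, lwork, info + i, params);
        }
      }
    }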
Rahul Batra
b4b97cd8e8 [ROCm]: Add jax-triton support for ROCm 2023-10-18 07:09:20 +00:00
Peter Hawkins
9404518201 [CUDA] Add code to JAX initialization that verifies that the CUDA libraries found are at least as new as the versions against which JAX was built.
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.

PiperOrigin-RevId: 568910422
2023-09-27 11:28:40 -07:00
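
A sketch of this style of check for the CUDA runtime itself: the build-time CUDART_VERSION macro gives the version JAX was compiled against, and cudaRuntimeGetVersion reports the library that was actually loaded:

    #include <cuda_runtime.h>
    #include <cstdio>

    bool CudaRuntimeIsNewEnough() {
      int loaded = 0;
      cudaRuntimeGetVersion(&loaded);  // version of the library found at runtime
      if (loaded < CUDART_VERSION) {
        std::fprintf(stderr,
                     "CUDA runtime %d is older than the %d JAX was built "
                     "against; check the install and LD_LIBRARY_PATH.\n",
                     loaded, CUDART_VERSION);
        return false;
      }
      return true;
    }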
Peter Hawkins
46ac9e2170 Use the default CSR matmul algorithm.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cuSPARSE SpMM CSR matmuls. In the past, cuSPARSE silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cuSPARSE 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior in all cases anyway, always use the default algorithm.

PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
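
A sketch of the call after this change, with the matrix descriptors assumed to be already built; the only difference from the old code is the algorithm enum:

    #include <cusparse.h>

    cusparseStatus_t CsrMatmul(cusparseHandle_t handle, const void* alpha,
                               cusparseSpMatDescr_t A, cusparseDnMatDescr_t B,
                               const void* beta, cusparseDnMatDescr_t C,
                               void* workspace) {
      return cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, A, B, beta,
                          C, CUDA_R_32F,
                          CUSPARSE_SPMM_ALG_DEFAULT,  // was CUSPARSE_SPMM_CSR_ALG3
                          workspace);
    }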
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Parker Schuh
0324cac888 Remove unused potrf kernels.
PiperOrigin-RevId: 489322021
2022-11-17 15:22:13 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it, we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
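
A sketch of the kind of cusolver call such a GPU kernel wraps: sytrd reduces a symmetric matrix to tridiagonal form and writes the diagonal d and off-diagonal e directly, matching the new contract:

    #include <cusolverDn.h>

    void Tridiagonalize(cusolverDnHandle_t handle, int n, float* A, float* d,
                        float* e, float* tau, float* work, int lwork,
                        int* info) {
      // d has length n and e has length n - 1; both are real-valued even for
      // the Hermitian (complex) variants.
      cusolverDnSsytrd(handle, CUBLAS_FILL_MODE_LOWER, n, A, n, d, e, tau,
                       work, lwork, info);
    }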
Tianjian Lu
46368e4e73 [sparse] Update the version guard on cuSPARSE SpMM and SpMV algorithms to cuSPARSE 11.7.1 onwards.
PiperOrigin-RevId: 486051658
2022-11-03 21:39:52 -07:00
Tianjian Lu
ef0f64ec5c [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 485441349
2022-11-01 16:01:50 -07:00
Jake VanderPlas
06c1d8efb5 Rollback of:
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.

Still breaks CUDA 11.1

PiperOrigin-RevId: 485151807
2022-10-31 14:38:47 -07:00
Tianjian Lu
66e75edd0b [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCm kernel code in jaxlib.
The code for CUDA and ROCm is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
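
An illustrative shim in the spirit of this change (macro and type names hypothetical): kernel code uses a single gpu* spelling that resolves to CUDA or ROCm at compile time:

    #ifdef JAX_GPU_CUDA
    #include <cuda_runtime.h>
    typedef cudaStream_t gpuStream_t;
    #define gpuStreamSynchronize cudaStreamSynchronize
    #define gpuGetLastError cudaGetLastError
    #else  // ROCm
    #include <hip/hip_runtime.h>
    typedef hipStream_t gpuStream_t;
    #define gpuStreamSynchronize hipStreamSynchronize
    #define gpuGetLastError hipGetLastError
    #endif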