126 Commits

Author SHA1 Message Date
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda,rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
jax authors
538a2be7fe Reverts 74b4d868e3751c1b4efa315ff8cf771faeb0b663
PiperOrigin-RevId: 736650031
2025-03-13 14:59:09 -07:00
Dan Foreman-Mackey
8b7cfcb33c Fix integer overflow in workspace size computations for experimental.rnn.*.
PiperOrigin-RevId: 736139471
2025-03-12 08:22:04 -07:00
Chris Jones
74b4d868e3 Add support for scratch buffers in jax_triton.
This is required to use device-side TMA descriptors.

PiperOrigin-RevId: 735985603
2025-03-11 20:49:33 -07:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
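A usage-level sketch, assuming the pivoted QR is exposed through `jax.scipy.linalg.qr` with `pivoting=True` (mirroring SciPy); the GPU path described here relies on a MAGMA installation:

```
import jax.numpy as jnp
from jax.scipy.linalg import qr  # assumed entry point, mirroring scipy.linalg.qr

a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 10.0]])

# Column-pivoted QR: p is the column permutation such that a[:, p] == q @ r.
q, r, p = qr(a, pivoting=True)
assert jnp.allclose(a[:, p], q @ r, atol=1e-5)
```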
2025-02-12 16:12:42 +00:00
Dan Foreman-Mackey
5bc17f7ec3 Remove the unused cu_cholesky_update kernel in favor of the FFI version.
This kernel wasn't allowed in export, so no backwards compatibility period is required. Even so, the FFI kernels were added 6 months ago.

PiperOrigin-RevId: 724359996
2025-02-07 08:48:15 -08:00
Dan Foreman-Mackey
c6e83903de Update RNN kernels to use FFI.
PiperOrigin-RevId: 724151647
2025-02-06 18:27:58 -08:00
Dan Foreman-Mackey
5e915d3307 Update the sparse GPU kernels in jaxlib to use the FFI.
Unlike the other more detailed ports, this version doesn't take full advantage of the features provided by the FFI. For example, it would be possible to update the kernels to use the ScratchAllocator instead of querying the workspace size during lowering. However, since these kernels are really only meant to be experimental, it's not obvious to me that it's worth the extra work to do anything more sophisticated.

PiperOrigin-RevId: 724016331
2025-02-06 11:45:57 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
George Necula
9f797990b5 Remove old backward compatibility mode for old PRNG custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (#20997).

PiperOrigin-RevId: 723077990
2025-02-04 07:34:52 -08:00
jax authors
41993fdb24 Merge pull request #25755 from ROCm:ci_rnn_final-upstream
PiperOrigin-RevId: 715856939
2025-01-15 10:40:54 -08:00
Ruturaj4
fe68eb8b25 [ROCm] Implement RNN support 2025-01-14 19:04:49 -06:00
Peter Hawkins
91ffb640a8 Use thread-safe initialization of LAPACK kernels.
Use absl::call_once instead of a GIL-protected global initialization.

In passing, also remove an unused function.

PiperOrigin-RevId: 714892175
2025-01-13 02:51:38 -08:00
Ilia Sergachev
f0e1c3cf36 Fix struct string encoding non-determinism in the RNN descriptor.
Boolean fields in the descriptor struct led to padding, which left random
bytes in the string representation of the struct and caused the HLO to vary
from run to run.
2025-01-09 12:57:09 +00:00
Dan Foreman-Mackey
32df37e6e4 Port symmetric tridiagonal reduction GPU kernel to FFI.
PiperOrigin-RevId: 704382200
2024-12-09 12:41:23 -08:00
Goran Flegar
c4d19ca83c Integrate Triton up to [9732c047](9732c04701)
PiperOrigin-RevId: 702397897
2024-12-03 10:55:56 -08:00
jax authors
aa05dc0b5c Automated Code Change
PiperOrigin-RevId: 699991540
2024-11-25 08:31:06 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
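For context, the `pure_callback` workaround mentioned here looks roughly like the following sketch (host-side NumPy `eig` wrapped in a callback; this is not the new native path):

```
import jax
import jax.numpy as jnp
import numpy as np

def eig_via_callback(a):
    # Run NumPy's eig on the host; outputs are declared complex since eig is
    # complex-valued in general.
    dtype = jnp.result_type(a.dtype, np.complex64)
    def host_eig(x):
        w, v = np.linalg.eig(np.asarray(x))
        return w.astype(dtype), v.astype(dtype)
    out_shapes = (jax.ShapeDtypeStruct(a.shape[:-1], dtype),  # eigenvalues
                  jax.ShapeDtypeStruct(a.shape, dtype))       # right eigenvectors
    return jax.pure_callback(host_eig, out_shapes, a)

w, v = jax.jit(eig_via_callback)(jnp.array([[0.0, 1.0], [-2.0, -3.0]]))
```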

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
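In user code, opting in looks roughly like this (a sketch based on the flag and environment variable named above; the MAGMA path only applies on the GPU backend):

```
import os

# Optional: point JAX at a non-standard MAGMA installation before first use.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax
import jax.numpy as jnp

# Opt in to the MAGMA-backed implementation of eig on GPU.
jax.config.update("jax_gpu_use_magma", "on")

a = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
w, v = jnp.linalg.eig(a)
```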

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
Dan Foreman-Mackey
21e98b5ce4 Fix overflow error in GPU batched linear algebra kernels.
As reported in https://github.com/jax-ml/jax/issues/24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.
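The failure mode is plain index arithmetic: with batch counts near int32 max, the byte offsets used to build the per-batch pointer array wrap around in 32-bit math. A minimal NumPy illustration (not the cuBLAS code itself):

```
import numpy as np

batch = 2**31 - 8            # batch size close to int32 max
bytes_per_matrix = 4         # e.g. a float32 1x1 matrix per batch element

# 32-bit arithmetic: the offsets of the last few batch elements wrap negative ...
offsets32 = np.arange(batch - 4, batch, dtype=np.int32) * np.int32(bytes_per_matrix)
print(offsets32)             # wrapped (invalid) byte offsets

# ... while 64-bit arithmetic gives the correct offsets.
offsets64 = np.arange(batch - 4, batch, dtype=np.int64) * np.int64(bytes_per_matrix)
print(offsets64)
```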

PiperOrigin-RevId: 695694648
2024-11-12 05:33:49 -08:00
Yuchen Jin
218f763255
(follow-up of PR #23852) add missing typename keyword to work with gcc
This update is a follow-up of PR #23852. The previous PR missed one place where the `typename` keyword needed to be added.
2024-11-07 23:55:38 -06:00
Dan Foreman-Mackey
6625a2b3ed Update Eigh kernel on GPU to use 64-bit interface when it is available.
Part of https://github.com/jax-ml/jax/issues/23413

PiperOrigin-RevId: 684546802
2024-10-10 12:59:37 -07:00
Eric Salo
713e909ba0 cleanup: remove api_version from BUILD files
PiperOrigin-RevId: 683658237
2024-10-08 09:44:15 -07:00
Ruturaj4
29a1cb766e [ROCM] add missing typename keyword to work with gcc 2024-09-23 14:42:01 -05:00
Dan Foreman-Mackey
afaa3bf43c Port GPU kernels for SVD to the FFI.
Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel, and while it does support lowering without access to a GPU (no more descriptor!), it only supports dynamic shapes in the batch dimensions. There are two main technical challenges:

1. The main `gesvd` kernels in cuSolver/hipSolver only support matrices with shape `(m, n)` with `m >= n`. This means that we need to transpose the inputs and outputs as part of the lowering rule when `m < n`. (Note: we actually just use C layouts instead of Fortran layouts to implement this case.) While this could be handled in the kernel, this seemed like a lot of work for somewhat limited benefit, and it would probably have performance implications.

2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that we need logic in the lowering rule to handle transposition and slicing. This makes it hard to have the algorithm selection be a parameter to the kernel.

Another note: cuSolver has a 64-bit implementation of the SVD, and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCM support, and I have tested it manually. This was a feature request from https://github.com/jax-ml/jax/issues/23413.
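The transpose trick in point 1 can be sketched in NumPy terms: when `m < n`, factor the tall transpose and map the factors back (the math only, not the kernel code):

```
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 5))      # wide matrix: m < n

# A solver that only accepts m >= n is handed the tall transpose instead ...
u_t, s, vh_t = np.linalg.svd(a.T, full_matrices=False)

# ... and the factors of `a` are recovered by transposing back:
# a = (a.T).T = (u_t @ diag(s) @ vh_t).T = vh_t.T @ diag(s) @ u_t.T
u, vh = vh_t.T, u_t.T
assert np.allclose(u @ np.diag(s) @ vh, a)
```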

PiperOrigin-RevId: 676839182
2024-09-20 07:34:50 -07:00
Dan Foreman-Mackey
a3bf75e442 Refactor gpusolver kernel definitions into separate build target.
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives.

PiperOrigin-RevId: 673832309
2024-09-12 07:11:36 -07:00
Dan Foreman-Mackey
7266e338c8 Update FFI target name for syrk operation to be consistent with other kernels.
PiperOrigin-RevId: 671870569
2024-09-06 13:21:38 -07:00
Dan Foreman-Mackey
1d12a9934c Port GPU kernel for symmetric eigendecomposition to FFI.
Of note, I moved the logic about which algorithm to use, and when to use the batched algorithm, into the kernel in order to support shape polymorphism and export.

PiperOrigin-RevId: 671853879
2024-09-06 12:23:04 -07:00
jax authors
f97bfc85a3 Implement symmetric_product() to produce a symmetric matrix: C = alpha * X @ X.T + beta * C
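A NumPy reference for the update this computes (the operation only; the jaxlib entry point is not shown here):

```
import numpy as np

def symmetric_product_reference(x, c, alpha=1.0, beta=1.0):
    # C <- alpha * X @ X.T + beta * C; symmetric whenever C is symmetric.
    return alpha * x @ x.T + beta * c

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 2))
c = np.eye(4)
out = symmetric_product_reference(x, c, alpha=2.0, beta=0.5)
assert np.allclose(out, out.T)
```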
PiperOrigin-RevId: 671845818
2024-09-06 11:58:20 -07:00
Sebastian Bodenstein
e3b8177af3 Internal change.
PiperOrigin-RevId: 671583042
2024-09-05 18:42:22 -07:00
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCM build.
Change in preparation for adding more presubmits for AMD ROCM.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Dan Foreman-Mackey
b56ed8eedd Port GPU kernel for Householder transformation to FFI.
PiperOrigin-RevId: 666305682
2024-08-22 05:23:09 -07:00
Krishna Haridasan
3713b966c2 Fix a potential segfault in triton kernel call caching
It is possible for a null pointer to be inserted into the cache and never updated with
a valid kernel call if an error occurs later during initialization. This change updates
the cache to store either an error or a valid kernel call.
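A hypothetical Python stand-in for the cache behavior described above (the actual code is C++; names here are illustrative):

```
_kernel_call_cache: dict[str, object] = {}

def get_kernel_call(key, initialize):
    # Store either the initialized kernel call or the initialization error,
    # so a failure never leaves a null/placeholder entry behind.
    if key not in _kernel_call_cache:
        try:
            _kernel_call_cache[key] = initialize()
        except Exception as err:
            _kernel_call_cache[key] = err
    cached = _kernel_call_cache[key]
    if isinstance(cached, Exception):
        raise cached            # re-raise the cached initialization failure
    return cached
```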

PiperOrigin-RevId: 666161091
2024-08-21 20:45:35 -07:00
Dan Foreman-Mackey
bd90968a25 Port the GPU Cholesky update custom call to the FFI.
PiperOrigin-RevId: 665319689
2024-08-20 05:46:03 -07:00
Dan Foreman-Mackey
71a93d0c87 Port QR factorization GPU kernel to FFI.
The biggest change here is that we now ignore the `info` parameter that is returned by `getrf`. In the previous implementation, we would return an error in the batched implementation, or set the relevant matrix entries to NaN in the non-batched version if `info != 0`. But, since info is only used for shape checking (see LAPACK, cuBLAS and cuSolver docs), I argue that we will never see `info != 0`, because we're including all the shape checks in the kernel already.

PiperOrigin-RevId: 665307128
2024-08-20 05:07:04 -07:00
Dan Foreman-Mackey
30d54ec6ff Refactor FFI shape inference functions to include dimension check.
Previously we always had two steps when extracting the batch size: (1) check the buffer has enough dimensions, (2) get the shape. And, in a few cases, this first check was missing. Now these steps are combined into one function that returns a StatusOr.

As part of this, I needed to fix our implementation of the `ASSIGN_OR_RETURN` macro to properly handle parentheses.

PiperOrigin-RevId: 664803225
2024-08-19 07:41:28 -07:00
Dan Foreman-Mackey
b6306e3953 Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder.
In the batched LU decomposition in cuBLAS, the output buffer is required to be an array of pointers to the appropriate batch matrices. Previously this reshaping was done on the host and then copied to the device, requiring a synchronization, but it seems straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high-priority change, but this seemed like a reasonable time to fix a longstanding TODO.

PiperOrigin-RevId: 663686539
2024-08-16 04:37:09 -07:00
Dan Foreman-Mackey
ad1bd38790 Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.

PiperOrigin-RevId: 662945116
2024-08-14 09:20:40 -07:00
jax authors
be4d52b814 Merge pull request #22667 from ROCm:rocm-jax-triton-add-get_arch_detail
PiperOrigin-RevId: 662007143
2024-08-12 02:30:49 -07:00
Rahul Batra
4b7c198a1c [ROCm]: Add get_arch_details for triton kernel call 2024-08-12 06:16:27 +00:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, an input `permutation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyway.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.
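For reference, the op expands LAPACK-style sequential row swaps into a full permutation; a NumPy sketch of the semantics (0-based pivots assumed):

```
import numpy as np

def lu_pivots_to_permutation_reference(pivots, permutation_size):
    # Apply each row swap in order to the identity permutation.
    perm = np.arange(permutation_size)
    for i, p in enumerate(pivots):
        perm[i], perm[p] = perm[p], perm[i]
    return perm

# Swap row 0 with row 2, then row 1 with row 2:
print(lu_pivots_to_permutation_reference(np.array([2, 2]), 3))  # [2 0 1]
```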

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
Dan Foreman-Mackey
8df0c3a9cc Port Getrf GPU kernel from custom call to FFI.
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Dan Foreman-Mackey
f20efc630f Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.

PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRNG custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add a backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Dan Foreman-Mackey
33a9db3943 Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to jaxlib/ffi_helpers.h.
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helpers.h header.

PiperOrigin-RevId: 651166306
2024-07-10 15:09:45 -07:00
Dan Foreman-Mackey
4f394828e1 Fix C++ registration of FFI handlers and consolidate gpu/linalg kernel implementation.
This change does a few things (arguably too many):

1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.

2. A second related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need to use these macros instead of the `XLA_FFI_DEFINE_HANDLER` version which produces a lambda, so that when XLA checks the address of the handler during registration it is consistent. Without this change, the downstream tests would continue to fail.

3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it was what I set out to do, so I'm suggesting them all together; I can split this in two if that would be preferred.

PiperOrigin-RevId: 651107659
2024-07-10 12:09:12 -07:00
Eugene Zhulenev
d49a0c5a63 [jax] Remove dead code from JAX custom calls defined as FFI handlers
PiperOrigin-RevId: 651025363
2024-07-10 08:11:12 -07:00
Eugene Zhulenev
1e03917c43 [xla:ffi] Use lazy decoding for Buffer<dtype,rank>
name                old cpu/op   new cpu/op   delta
BM_AnyBufferArgX1   11.0ns ± 3%  11.2ns ±10%   +1.76%  (p=0.000 n=67+69)
BM_AnyBufferArgX4   12.4ns ± 3%  12.4ns ± 4%   -0.31%  (p=0.006 n=69+69)
BM_BufferArgX1      12.5ns ± 1%  11.1ns ± 4%  -11.20%  (p=0.000 n=62+76)
BM_BufferArgX4      19.1ns ± 1%  14.4ns ± 4%  -24.84%  (p=0.000 n=64+73)
BM_BufferArgX8      36.0ns ± 5%  20.3ns ± 4%  -43.59%  (p=0.000 n=79+75)
BM_TupleOfI32Attrs  66.4ns ± 1%  66.4ns ± 2%   -0.03%  (p=0.000 n=66+72)

PiperOrigin-RevId: 650691450
2024-07-09 11:07:25 -07:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the FFI registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is no longer
possible to register a call target twice.

The fix here is to roll back the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00