24 Commits

Author SHA1 Message Date
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda,rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
Dan Foreman-Mackey
5e915d3307 Update the sparse GPU kernels in jaxlib to use the FFI.
Unlike the other more detailed ports, this version doesn't take full advantage of the features provided by the FFI. For example, it would be possible to update the kernels to use the ScratchAllocator instead of querying the workspace size during lowering. However, since these kernels are really only meant to be experimental, it's not obvious to me that it's worth the extra work to do anything more sophisticated.
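For illustration, a rough sketch of what the ScratchAllocator alternative could look like (not part of this change; the binding macros and types follow the XLA FFI headers, but the handler is hypothetical and signatures are approximate):

    // Hypothetical handler: allocate the cuSPARSE workspace through the FFI
    // scratch allocator at run time instead of baking a workspace size into
    // the lowering.
    #include <optional>
    #include "xla/ffi/api/ffi.h"

    namespace ffi = xla::ffi;

    static ffi::Error SparseOpImpl(cudaStream_t stream,
                                   ffi::ScratchAllocator scratch,
                                   ffi::AnyBuffer x,
                                   ffi::Result<ffi::AnyBuffer> y) {
      size_t workspace_size = 0;  // would come from cusparse*_bufferSize(...)
      std::optional<void*> workspace = scratch.Allocate(workspace_size);
      if (!workspace.has_value()) {
        return ffi::Error(ffi::ErrorCode::kResourceExhausted,
                          "failed to allocate workspace");
      }
      // ... call the cuSPARSE routine using *workspace ...
      return ffi::Error::Success();
    }

    XLA_FFI_DEFINE_HANDLER_SYMBOL(
        kSparseOp, SparseOpImpl,
        ffi::Ffi::Bind()
            .Ctx<ffi::PlatformStream<cudaStream_t>>()
            .Ctx<ffi::ScratchAllocator>()
            .Arg<ffi::AnyBuffer>()
            .Ret<ffi::AnyBuffer>());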

PiperOrigin-RevId: 724016331
2025-02-06 11:45:57 -08:00
jax authors
aa05dc0b5c Automated Code Change
PiperOrigin-RevId: 699991540
2024-11-25 08:31:06 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
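As a hedged sketch of the loading strategy described above (illustrative; the real lookup in jaxlib may differ in its details):

    // Illustrative: prefer an explicit JAX_GPU_MAGMA_PATH, otherwise try the
    // standard soname. If this returns nullptr, MAGMA is unavailable and the
    // LAPACK-on-host path would be used instead.
    #include <cstdlib>
    #include <dlfcn.h>

    void* LoadMagma() {
      const char* path = std::getenv("JAX_GPU_MAGMA_PATH");
      if (path == nullptr || *path == '\0') path = "libmagma.so";
      return dlopen(path, RTLD_NOW | RTLD_LOCAL);
    }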

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
Eric Salo
713e909ba0 cleanup: remove api_version from BUILD files
PiperOrigin-RevId: 683658237
2024-10-08 09:44:15 -07:00
Dan Foreman-Mackey
a3bf75e442 Refactor gpusolver kernel definitions into separate build target.
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives.
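A hedged sketch of the shim idea (the cuSolver entry points are real; the wrapper layer is illustrative, not the actual jaxlib interface target):

    // Illustrative shims: dtype overloads around cuSolver's getrf so the FFI
    // kernel code in solver_kernels_ffi can call a single name for all dtypes.
    #include <cusolverDn.h>

    inline cusolverStatus_t GetrfBufferSize(cusolverDnHandle_t h, int m, int n,
                                            float* a, int lda, int* lwork) {
      return cusolverDnSgetrf_bufferSize(h, m, n, a, lda, lwork);
    }
    inline cusolverStatus_t GetrfBufferSize(cusolverDnHandle_t h, int m, int n,
                                            double* a, int lda, int* lwork) {
      return cusolverDnDgetrf_bufferSize(h, m, n, a, lda, lwork);
    }
    inline cusolverStatus_t Getrf(cusolverDnHandle_t h, int m, int n, float* a,
                                  int lda, float* work, int* ipiv, int* info) {
      return cusolverDnSgetrf(h, m, n, a, lda, work, ipiv, info);
    }
    inline cusolverStatus_t Getrf(cusolverDnHandle_t h, int m, int n, double* a,
                                  int lda, double* work, int* ipiv, int* info) {
      return cusolverDnDgetrf(h, m, n, a, lda, work, ipiv, info);
    }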

PiperOrigin-RevId: 673832309
2024-09-12 07:11:36 -07:00
Sebastian Bodenstein
e3b8177af3 Internal change.
PiperOrigin-RevId: 671583042
2024-09-05 18:42:22 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Dan Foreman-Mackey
b6306e3953 Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder.
In the batched LU decomposition in cuBLAS, the output buffer is required to be a pointer of pointers to the appropriate batch matrices. Previously this reshaping was done on the host and then copied to the device, requiring a synchronization, but it seems straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high priority change, but this seemed like a reasonable time to fix a longstanding TODO.
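A minimal sketch of such a batch-pointer kernel (illustrative, not the jaxlib implementation):

    // Illustrative: build the device array of per-batch matrix pointers that
    // cublas<T>getrfBatched expects, directly on the GPU, so no host staging
    // buffer or stream synchronization is needed.
    __global__ void MakeBatchPointers(char* base, void** out, int batch,
                                      size_t bytes_per_matrix) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < batch) out[i] = base + static_cast<size_t>(i) * bytes_per_matrix;
    }

    // Launched on the same stream as the factorization, e.g.:
    //   MakeBatchPointers<<<(batch + 255) / 256, 256, 0, stream>>>(
    //       buffer, ptrs, batch, sizeof(T) * n * n);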

PiperOrigin-RevId: 663686539
2024-08-16 04:37:09 -07:00
Dan Foreman-Mackey
ad1bd38790 Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.
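In other words, the batched-vs-looped choice is now made at run time inside the kernel, when the actual batch size is known, rather than at lowering time. A hedged sketch (the threshold and names are invented):

    // Illustrative only: pick the implementation inside the FFI kernel so the
    // decision still works when shapes are symbolic at lowering time.
    void LuDispatch(cudaStream_t stream, int batch, int n /*, buffers... */) {
      const bool use_batched = batch > 1 && n <= 128;  // made-up heuristic
      if (use_batched) {
        // ... cublas<T>getrfBatched over all matrices at once ...
      } else {
        // ... loop calling cusolverDn<T>getrf per matrix ...
      }
    }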

PiperOrigin-RevId: 662945116
2024-08-14 09:20:40 -07:00
Dan Foreman-Mackey
8df0c3a9cc Port Getrf GPU kernel from custom call to FFI.
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Dan Foreman-Mackey
f20efc630f Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.
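For context, a stripped-down sketch of the handle-pool pattern these classes implement (illustrative; the real implementations are RAII-based and more careful):

    // Illustrative pool: reuse expensive cuBLAS handles and bind each
    // borrowed handle to the caller's stream.
    #include <mutex>
    #include <vector>
    #include <cublas_v2.h>

    class BlasHandlePoolSketch {
     public:
      cublasHandle_t Borrow(cudaStream_t stream) {
        std::lock_guard<std::mutex> lock(mu_);
        cublasHandle_t h;
        if (!free_.empty()) {
          h = free_.back();
          free_.pop_back();
        } else {
          cublasCreate(&h);  // error handling elided
        }
        cublasSetStream(h, stream);
        return h;
      }
      void Return(cublasHandle_t h) {
        std::lock_guard<std::mutex> lock(mu_);
        free_.push_back(h);
      }
     private:
      std::mutex mu_;
      std::vector<cublasHandle_t> free_;
    };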

PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
Dan Foreman-Mackey
4f394828e1 Fix C++ registration of FFI handlers and consolidate gpu/linalg kernel implementation.
This change does a few things (arguably too many):

1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.

2. A second, related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need these macros instead of the `XLA_FFI_DEFINE_HANDLER` version, which produces a lambda, so that the handler address XLA checks during registration is consistent (see the sketch after this list). Without this change, the downstream tests would continue to fail.

3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it was what I set out to do, so I'm suggesting them all together; I can split this change in two if that would be preferred.
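A hedged sketch of the declare/define/register pattern from point 2 (the macros come from the XLA FFI headers; the handler, binding, and call-target names are illustrative and the signatures approximate):

    // kernels.h -- declare the handler symbol so that every translation unit,
    // including jaxlib/gpu/gpu_kernels.cc, sees one and the same address.
    #include "xla/ffi/api/ffi.h"
    XLA_FFI_DECLARE_HANDLER_SYMBOL(kExampleKernel);

    // kernels.cc -- define the handler once. Unlike XLA_FFI_DEFINE_HANDLER,
    // which expands to a lambda, this gives the handler a single stable symbol.
    XLA_FFI_DEFINE_HANDLER_SYMBOL(
        kExampleKernel, ExampleKernelImpl,
        xla::ffi::Ffi::Bind()
            .Ctx<xla::ffi::PlatformStream<cudaStream_t>>()
            .Arg<xla::ffi::AnyBuffer>()
            .Ret<xla::ffi::AnyBuffer>());

    // gpu_kernels.cc -- the C++-side registration now points at the same
    // symbol as the Python-side registration, so XLA's duplicate check passes.
    XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "example_kernel_ffi", "CUDA",
                             kExampleKernel);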

PiperOrigin-RevId: 651107659
2024-07-10 12:09:12 -07:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Chris Jones
4ac2bdc2b1 [jax_triton] Add user-specified name field to serialized format.
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Chris Jones
31b862dd56 [jax_triton] Split C++ only parts of Triton custom callback from Python parts.
Register the callback with the default call target name from C++, enabling Triton calls with the default name to work in C++-only contexts (e.g. serving).
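A hedged sketch of what the C++-side registration might look like, using XLA's custom-call registry (the function and target names are placeholders, not the actual jax_triton symbols):

    // Illustrative: registering a custom-call target from C++ means a
    // serialized program that uses this name can run in a C++-only process,
    // with no Python import required.
    #include <cuda.h>
    #include "xla/service/custom_call_target_registry.h"

    extern "C" void TritonKernelCall(CUstream stream, void** buffers,
                                     const char* opaque, size_t opaque_len);

    XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
        "triton_kernel_call", reinterpret_cast<void*>(&TritonKernelCall),
        "CUDA");

Exposing the callback in a header (as in the next entry below) lets users perform the same registration themselves under a non-default target name.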

PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
Chris Jones
f238667492 Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information, is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which raises the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
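A minimal sketch of such a pass (the MLIR walk/UnknownLoc APIs are real; the pass itself is illustrative):

    // Illustrative MLIR pass: replace every op's location with UnknownLoc so
    // the serialized StableHLO used as a cache key no longer depends on
    // source line numbers.
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/IR/Location.h"
    #include "mlir/Pass/Pass.h"

    namespace {
    struct StripLocationsPass
        : public mlir::PassWrapper<StripLocationsPass,
                                   mlir::OperationPass<mlir::ModuleOp>> {
      void runOnOperation() override {
        mlir::MLIRContext* ctx = &getContext();
        getOperation().walk(
            [&](mlir::Operation* op) { op->setLoc(mlir::UnknownLoc::get(ctx)); });
      }
    };
    }  // namespace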

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
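A hedged sketch of the shim idea (a few representative typedefs and defines; the real header covers far more of the runtime and library APIs, and the guard macros here are illustrative):

    // Illustrative vendor shim: the shared kernel code is written against a
    // "gpu" vocabulary that maps to CUDA or ROCm at compile time.
    #ifdef JAX_GPU_CUDA
    #include <cuda_runtime_api.h>
    typedef cudaStream_t gpuStream_t;
    typedef cudaError_t gpuError_t;
    #define gpuSuccess cudaSuccess
    #define gpuMemcpyAsync cudaMemcpyAsync
    #define gpuStreamSynchronize cudaStreamSynchronize
    #else  // ROCm
    #include <hip/hip_runtime_api.h>
    typedef hipStream_t gpuStream_t;
    typedef hipError_t gpuError_t;
    #define gpuSuccess hipSuccess
    #define gpuMemcpyAsync hipMemcpyAsync
    #define gpuStreamSynchronize hipStreamSynchronize
    #endif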

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00