13 Commits

Author SHA1 Message Date
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
Dan Foreman-Mackey
5bc17f7ec3 Remove the unused cu_cholesky_update kernel in favor of the FFI version.
This kernel wasn't allowed in export, so no backwards compatibility period is required. Even so, the FFI kernels were added 6 months ago.

PiperOrigin-RevId: 724359996
2025-02-07 08:48:15 -08:00
Dan Foreman-Mackey
c6e83903de Update RNN kernels to use FFI.
PiperOrigin-RevId: 724151647
2025-02-06 18:27:58 -08:00
Dan Foreman-Mackey
5e915d3307 Update the sparse GPU kernels in jaxlib to use the FFI.
Unlike the other more detailed ports, this version doesn't take full advantage of the features provided by the FFI. For example, it would be possible to update the kernels to use the ScratchAllocator instead of querying the workspace size during lowering. However, since these kernels are really only meant to be experimental, it's not obvious to me that it's worth the extra work to do anything more sophisticated.

PiperOrigin-RevId: 724016331
2025-02-06 11:45:57 -08:00
jax authors
41993fdb24 Merge pull request #25755 from ROCm:ci_rnn_final-upstream
PiperOrigin-RevId: 715856939
2025-01-15 10:40:54 -08:00
Ruturaj4
fe68eb8b25 [ROCm] Implement RNN support 2025-01-14 19:04:49 -06:00
Peter Hawkins
91ffb640a8 Use thread-safe initialization of LAPACK kernels.
Use absl::call_once instead of a GIL-protected global initialization.

In passing, also remove an unused function.

PiperOrigin-RevId: 714892175
2025-01-13 02:51:38 -08:00
Peter Hawkins
90d8f37863 Rename pybind_extension to nanobind_extension.
We have no remaining uses of pybind11 outside a GPU custom call example.

PiperOrigin-RevId: 712608834
2025-01-06 11:53:44 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
jax authors
57ef7a4a59 Merge pull request #24274 from ROCm:ci_linalg_fix
PiperOrigin-RevId: 685717437
2024-10-14 08:33:33 -07:00
Ruturaj4
ee223d4004 [ROCm] jaxlib linalg fix 2024-10-13 20:25:18 -05:00
Dan Foreman-Mackey
a3bf75e442 Refactor gpusolver kernel definitions into separate build target.
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives.

PiperOrigin-RevId: 673832309
2024-09-12 07:11:36 -07:00
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCM build.
Change in preparation for adding more presubmits for AMD ROCM.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00