57 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
2ce88c950a Deprecate alpha argument to trsm LAPACK kernel.
(Part of general cleanups of the lax.linalg submodule.)

This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release).

Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.

PiperOrigin-RevId: 730928808
2025-02-25 10:04:29 -08:00
Jan Naumann
e03fe3a06d Implement SVD algorithm based on QR for CPU targets
In a recent jax release the SvdAlgorithm parameter has been added
to the jax.lax.linalg.svd function. Currently, for CPU targets
still only the divide and conquer algorithm from LAPACK is
supported (gesdd).

This commits adds the functionality to select the QR based
algorithm on CPU as well. Mainly it addes the wrapper code
to call the gesvd function of LAPACK using the FFI interface.

Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
2025-02-22 15:24:57 +01:00
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Peter Hawkins
91ffb640a8 Use thread-safe initialization of LAPACK kernels.
Use absl::call_once instead of a GIL-protected global initialization.

In passing, also remove an unused function.

PiperOrigin-RevId: 714892175
2025-01-13 02:51:38 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Peter Hawkins
90d8f37863 Rename pybind_extension to nanobind_extension.
We have no remaining uses of pybind11 outside a GPU custom call example.

PiperOrigin-RevId: 712608834
2025-01-06 11:53:44 -08:00
Peter Hawkins
61dd041225 Suppress MSAN warnings from SVD that are showing up in CI.
In our MSAN CI, the copy of LAPACK we use is not MSAN-instrumented, leading to false positives. Suppress those false-positives via annotations.

PiperOrigin-RevId: 712607044
2025-01-06 11:49:05 -08:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
Paweł Paruzel
23fdb91252 Port Schur Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685689593
2024-10-14 06:46:42 -07:00
Paweł Paruzel
ec68d420fe Port Tridiagonal Reduction to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685679646
2024-10-14 06:02:59 -07:00
Paweł Paruzel
2082662bb1 Port Hessenberg Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 671283487
2024-09-05 01:59:32 -07:00
Peter Hawkins
1ab3119d43 Add some msan suppressions to the LAPACK symmetric eigendecomposition FFI call.
This fixes some msan false positives in our CI, since we do not msan-instrument Fortran code.

PiperOrigin-RevId: 669385248
2024-08-30 11:12:45 -07:00
Paweł Paruzel
4342c0c0f3 Determine LAPACK workspace during Householder Product Kernel runtime
Workspace dependency was removed, and the info parameter is ignored now.

PiperOrigin-RevId: 669246058
2024-08-30 02:06:16 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Paweł Paruzel
4786930a4c Determine LAPACK workspace during Eigenvalue Kernels runtime
PiperOrigin-RevId: 666285759
2024-08-22 04:09:34 -07:00
Paweł Paruzel
a72d46c549 Ignore LAPACK info parameter for QR Factorization
The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call.

PiperOrigin-RevId: 666241215
2024-08-22 01:38:38 -07:00
Dan Foreman-Mackey
30d54ec6ff Refactor FFI shape inference functions to include dimension check.
Previously we always had two steps when extracting the batch size: (1) check the buffer has enough dimensions, (2) get the shape. And, in a few cases, this first check was missing. Now these steps are combined into one function that returns a StatusOr.

As part of this, I needed to fix our implementation of the `ASSIGN_OR_RETURN` macro to properly handle parentheses.

PiperOrigin-RevId: 664803225
2024-08-19 07:41:28 -07:00
Paweł Paruzel
acacf8884e Determine LAPACK workspace during QR Factorization Kernel runtime
PiperOrigin-RevId: 663641199
2024-08-16 01:20:50 -07:00
Paweł Paruzel
5fc992e5e1 Determine LAPACK workspaces during SVD kernel runtime
The SVD kernel implementation used to require workspace shapes to be determined prior to the custom call on the JAX's side. The new FFI kernels need not demand these shapes to be specified anymore. They are evaluated during kernel runtime.

PiperOrigin-RevId: 662413273
2024-08-13 01:17:44 -07:00
Paweł Paruzel
b2a469b361 Port Eigenvalue Decompositions to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 659492696
2024-08-05 03:18:13 -07:00
Dan Foreman-Mackey
618754d829 Move some common helper functions from lapack_kernels to ffi_helpers.
There were two helper functions for implementing FFI calls that were included directly alongside jaxlib's CPU kernels that will be useful for the GPU kernels as well. This moves those functions into ffi_helpers so that they are accessible from there too.

PiperOrigin-RevId: 658002501
2024-07-31 07:38:33 -07:00
Dan Foreman-Mackey
ff4e0b1214 Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate handler errors.
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.

PiperOrigin-RevId: 657186057
2024-07-29 06:59:44 -07:00
Paweł Paruzel
54fe6e68a0 Port Triangular Solve to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 655484166
2024-07-24 02:15:41 -07:00
Paweł Paruzel
5cce394428 Port Householder Product to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 651691430
2024-07-12 01:36:41 -07:00
Paweł Paruzel
86ab50d92f Port QR Factorization to XLA's FFI
PiperOrigin-RevId: 651396166
2024-07-11 07:07:39 -07:00
Dan Foreman-Mackey
33a9db3943 Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to a jaxlib/ffi_helpers.h.
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helper header.

PiperOrigin-RevId: 651166306
2024-07-10 15:09:45 -07:00
Eugene Zhulenev
d49a0c5a63 [jax] Remove dead code from JAX custom calls defined as FFI handlers
PiperOrigin-RevId: 651025363
2024-07-10 08:11:12 -07:00
Eugene Zhulenev
1e03917c43 [xla:ffi] Use lazy decoding for Buffer<dtype,rank>
name                old cpu/op   new cpu/op   delta
BM_AnyBufferArgX1   11.0ns ± 3%  11.2ns ±10%   +1.76%  (p=0.000 n=67+69)
BM_AnyBufferArgX4   12.4ns ± 3%  12.4ns ± 4%   -0.31%  (p=0.006 n=69+69)
BM_BufferArgX1      12.5ns ± 1%  11.1ns ± 4%  -11.20%  (p=0.000 n=62+76)
BM_BufferArgX4      19.1ns ± 1%  14.4ns ± 4%  -24.84%  (p=0.000 n=64+73)
BM_BufferArgX8      36.0ns ± 5%  20.3ns ± 4%  -43.59%  (p=0.000 n=79+75)
BM_TupleOfI32Attrs  66.4ns ± 1%  66.4ns ± 2%   -0.03%  (p=0.000 n=66+72)

PiperOrigin-RevId: 650691450
2024-07-09 11:07:25 -07:00
Paweł Paruzel
4e1a66ea21 Avoid throwing exceptions in LAPACK kernel code
PiperOrigin-RevId: 650569943
2024-07-09 03:57:50 -07:00
Paweł Paruzel
532be68461 Port Singular Value Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 650212574
2024-07-08 05:19:53 -07:00
Dan Foreman-Mackey
98b87540a7 Avoid throwing exceptions in LAPACK CPU kernels.
When an FFI kernel is executed, there isn't any global try/except block (I think!) so it's probably a good idea to avoid throwing.
Instead, it should be safer to handle mapping failures to ffi::Error manually.

PiperOrigin-RevId: 647348889
2024-06-27 09:41:07 -07:00
Paweł Paruzel
63aab133f1 Port LU Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 644845277
2024-06-19 17:31:25 -07:00
Paweł Paruzel
3d39b6e752 Port Cholesky Factorization to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 642954763
2024-06-13 05:44:36 -07:00
Paweł Paruzel
5fcd50b7fa Refactor kernel function assigment
PiperOrigin-RevId: 641255192
2024-06-07 08:20:31 -07:00
George Necula
3bcb8d6831 Remove DUCC FFT from jaxlib
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.

PiperOrigin-RevId: 638132572
2024-05-28 21:12:23 -07:00
George Necula
d92f4ae157 Reverts 9db5e693ebb4ad786c6e52b562cf32aeaba2e7e1
PiperOrigin-RevId: 628362293
2024-04-26 04:14:34 -07:00
jax authors
9db5e693eb Reverts 6bfbb4593a42fced91ba50de47271af425c74c20
PiperOrigin-RevId: 628035616
2024-04-25 04:53:22 -07:00
George Necula
6bfbb4593a Remove old ducc_fft custom call.
Starting in June 2023 we have switched the CPU lowering for FFT to use
the new custom call dynamic_ducc_fft. We are now out of the backwards
compatibility window and we remove the old ducc_fft.

We need to keep dynamic_ducc_fft a little bit longer (May 2024).

PiperOrigin-RevId: 627981921
2024-04-25 00:29:11 -07:00
jax authors
16b29a6930 Merge pull request #19288 from pearu:pearu/int32-overflow
PiperOrigin-RevId: 608701959
2024-02-20 12:43:16 -08:00
Pearu Peterson
3fa1033ac1 Prevent silent overflow in lapack worker size calculations.
Add -fexceptions to building lapack_kernels
2024-02-20 11:04:06 +02:00
Shashank Viswanadha
bd46e5c960 Add nb::arg to nanobind definitions to generate better python annotations.
PiperOrigin-RevId: 586721759
2023-11-30 10:39:28 -08:00
Shashank Viswanadha
350b7c56b8 Add python stub files for jaxlib/cpu C++ Python extensions.
PiperOrigin-RevId: 585990748
2023-11-28 08:45:24 -08:00
Peter Hawkins
49c80e68d1 Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
Antonio Sanchez
873cffc776 Use TSL's import of DUCC.
This is a necessary first step before adding DUCC support to XLA,
otherwise the JAX tests in the XLA repo pull from JAX's copy,
which has slightly different build rules.

PiperOrigin-RevId: 576880208
2023-10-26 08:26:56 -07:00
Peter Hawkins
dbf13252f0 Copybara import of the project:
--
3905d6123bdc22f505934242363fda426c99c4cf by Peter Hawkins <phawkins@google.com>:

Update flatbuffers.

Use upstream flatbuffer bazel scripts, with a couple of small patches to fix:
* https://github.com/google/flatbuffers/issues/8087 (remove npm references)
* https://github.com/google/flatbuffers/pull/8088 (fix flatc build failure due to main() removal by linker)

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17502 from hawkinsp:fb 3905d6123bdc22f505934242363fda426c99c4cf
PiperOrigin-RevId: 563543954
2023-09-07 14:27:25 -07:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Antonio Sanchez
a600020346 Update ducc to commit: 2b2cead005e08d2632478e831d7f45da754162dc
NOTE: this version of DUCC has a breaking change, where the fft.h header
no longer contains the definitions of many fft functions - instead they exist
within fft1d_impl.h and fftnd_impl.h.
PiperOrigin-RevId: 554567641
2023-08-07 13:06:43 -07:00
George Necula
b9c0658fcf Add support for dynamic shapes to jax.fft.
The idea is that we take all the values that can contain dimension sizes
from the descriptor (shape, strides_in, strides_out) and we pass them as
1-d tensor operands. We also pass as an operand the output_shape, so that
we can use the hlo.CustomCallOp `indices_of_output_shapes` attribute to
tell the shape refinement how to compute the shape of the result.

We keep the old descriptor and the ducc_fft registration for the old
C++ custom targets for backwards compatibility (for 6 months). That behavior
is tested by back_compat_test.py.

The one downside of this implementation is that it moves some of the
ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This
was necessary because that code computes with dimensions that are now
dynamic. In JAX we have support for evaluating dynamic shapes and turning
them into 1-d tensors.

Also added backwards compatibility test for dynamic_ducc_fft and kept the
old test for ducc_fft.

PiperOrigin-RevId: 541168692
2023-06-17 04:50:54 -07:00