55 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
f93c2a1aa5 Add and test support for partitioning of batch dimensions in lax.linalg.
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU will insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in kernel. I'm going to fix these in follow up changes.

PiperOrigin-RevId: 731732301
2025-02-27 08:16:16 -08:00
Dan Foreman-Mackey
154e4506c0 Some lax.linalg housekeeping.
The main aim here is to clean up lax.linalg to make it a bit easier to maintain and update with new features (e.g. batch partitioning - coming soon!). In this change, I removes some code duplication by consolidate most of the lowering logic into a helper function, and identifying some other common patterns. As part of this, I moved the remaining lowering rules from `jaxlib.lapack` into `lax.linalg`.

PiperOrigin-RevId: 725223882
2025-02-10 08:27:18 -08:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
Paweł Paruzel
d474feda9e Activate Tridiagonal Reduction to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction.

PiperOrigin-RevId: 704234350
2024-12-09 04:36:59 -08:00
Paweł Paruzel
9081e85d68 Activate Schur Decomposition to XLA's FFI
PiperOrigin-RevId: 703484916
2024-12-06 06:49:53 -08:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we do use some heuristics to based on the matrix sizes to select the algorithm that is used, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But, I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
Dan Foreman-Mackey
67f24df740 Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
2024-10-04 12:38:26 -07:00
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Paweł Paruzel
6e9a53690c Activate Hessenberg Decomposition to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Hessenberg Decomposition.

PiperOrigin-RevId: 681047625
2024-10-01 09:20:06 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
Paweł Paruzel
414eb90f5b Activate Householder Product to XLA's FFI
PiperOrigin-RevId: 670196460
2024-09-02 06:19:01 -07:00
Paweł Paruzel
3c6103f2df Activate Eigenvalue Decompositions to XLA's FFI
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.

PiperOrigin-RevId: 668381949
2024-08-28 03:53:49 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
Paweł Paruzel
354293da48 Activate Singular Value Decomposition to XLA's FFI
PiperOrigin-RevId: 662436635
2024-08-13 02:41:57 -07:00
Paweł Paruzel
6b0b222a38 Activate LU Decomposition to XLA's FFI
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
George Necula
24f9011d49 Align the custom_call implementation in mlir and hlo_helpers.
PiperOrigin-RevId: 562385810
2023-09-03 13:07:22 -07:00
Peter Hawkins
da097e70ec Don't hold a reference to _lapack.initialize().
Nanobind prints a warning if a reference to a nanobind-bound function is held at process atexit() time. But there's no particularly good reason we need to hold a function reference that long anyway in this case.

PiperOrigin-RevId: 561795469
2023-08-31 16:46:04 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
George Necula
c6a60054b9 [shape_poly] linalg.schur: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543533821
2023-06-26 13:59:01 -07:00
George Necula
a91412e1e7 [shape_poly] linalg.triangular_solve: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543506845
2023-06-26 12:13:12 -07:00
George Necula
ea0e50f765 [shape_poly] Refactor support for dynamic shapes for linalg.eig and linalg.eigh
The support for dynamic shapes for linalg.eig and linalg.eigh has been added
before we added the helper function `mk_result_types_and_shapes`, which has
been used for all other linalg primitives. Here we refactor linalg.eig and
linalg.eigh support to use these helper functions and follow the same style
as for other linalg primitives.

PiperOrigin-RevId: 543495381
2023-06-26 11:31:31 -07:00
George Necula
aadcec2b1b [shape_poly] Refactor some older support for shape polymorphism for linalg.
Moving some helper functions from linalg.py to hlo_helpers.py, so that we
can reuse them for more custom calls, including those in gpu_solver.

Also renamed some helper functions, e.g., _hlo_s32 -> hlo_s32, and ir_constant_i32 -> hlo_s32.

PiperOrigin-RevId: 543448560
2023-06-26 08:40:22 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
George Necula
bbc6f30693 [shape_poly] linalg.lu: for shape polymorphism for native serialization on CPU.
We support polymorphism only on the batch sizes for now. The
jaxlib and C++ code support full dynamic shapes.

Also added backwards compatibility tests for the LU custom calls
for CPU, and improved the checking of LU results by checking
the invariant for the result as opposed to checking goldens.

PiperOrigin-RevId: 542852925
2023-06-23 07:25:24 -07:00
George Necula
f02d122366 [shape_poly] linalg.cholesky: add support for shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 542497269
2023-06-22 02:21:24 -07:00
George Necula
d940db609d [shape_poly] linalg.qr: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 542489216
2023-06-22 01:40:08 -07:00
George Necula
92288d3071 [shape_poly] linalg.svd: shape polymorphism for native serialization on CPU
PiperOrigin-RevId: 542483203
2023-06-22 01:05:59 -07:00
George Necula
3adfe321b0 [shape_poly] linalg.eig: shape polymorphism with native serialization on CPU
The backwards compatibility tests to be added separately.

PiperOrigin-RevId: 541122069
2023-06-16 23:59:18 -07:00
George Necula
a2ac510dc3 [shape_poly] Add support for dynamic shapes for eigh
We can only handle dynamic sizes for the batch dimensions for now.

PiperOrigin-RevId: 529001830
2023-05-02 23:27:59 -07:00
Eugene Burmako
f337c00ed5 Remove *_mhlo compatibility shims from jaxlib
We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted.

PiperOrigin-RevId: 510820007
2023-02-19 09:03:14 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
2246887f7b Add input-output aliasing annotations for LAPACK calls on CPU.
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Anish Tondwalkar
aeabfc4b37 [mhlo] ConstOp -> ConstantOp, a la HLO
This begins the process of aligning names within MHLO with names in HLO, so we
can mechanically ensure that they remain in sync.

PiperOrigin-RevId: 454984812
2022-06-14 16:33:24 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
909c0328b0 Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead.

This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation.

Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.

Fixes https://github.com/google/jax/issues/2322

PiperOrigin-RevId: 449033350
2022-05-16 12:59:57 -07:00
Peter Hawkins
883cf2b1e9 Refactor custom call building code in jaxlib to use a helper function.
Refactoring only, no functional changes intended.

This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.

Fixes https://github.com/google/jax/issues/10474

PiperOrigin-RevId: 447076192
2022-05-06 14:51:24 -07:00
Matthew Johnson
0c5864a220 add xla_client._version checks for mhlo.ConstOp signature
fix break from 0cf08d0c6841332240cae873e4b4cf9a9b313373
2022-05-04 09:54:06 -07:00
jax authors
0cf08d0c68 Integrate LLVM at llvm/llvm-project@46cc04de34
Updates LLVM usage to match
[46cc04de341b](https://github.com/llvm/llvm-project/commit/46cc04de341b)

PiperOrigin-RevId: 446430294
2022-05-04 05:31:41 -07:00
Jake VanderPlas
c6343ddf8e jax.scipy.linalg.schur: error on 16-bit floats
Fixes https://github.com/google/jax/issues/10530

PiperOrigin-RevId: 446279906
2022-05-03 13:47:44 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00