45 Commits

Author SHA1 Message Date
Paweł Paruzel
414eb90f5b Activate Householder Product to XLA's FFI
PiperOrigin-RevId: 670196460
2024-09-02 06:19:01 -07:00
Paweł Paruzel
3c6103f2df Activate Eigenvalue Decompositions to XLA's FFI
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.

PiperOrigin-RevId: 668381949
2024-08-28 03:53:49 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
Paweł Paruzel
354293da48 Activate Singular Value Decomposition to XLA's FFI
PiperOrigin-RevId: 662436635
2024-08-13 02:41:57 -07:00
Paweł Paruzel
6b0b222a38 Activate LU Decomposition to XLA's FFI
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
George Necula
24f9011d49 Align the custom_call implementation in mlir and hlo_helpers.
PiperOrigin-RevId: 562385810
2023-09-03 13:07:22 -07:00
Peter Hawkins
da097e70ec Don't hold a reference to _lapack.initialize().
Nanobind prints a warning if a reference to a nanobind-bound function is held at process atexit() time. But there's no particularly good reason we need to hold a function reference that long anyway in this case.

PiperOrigin-RevId: 561795469
2023-08-31 16:46:04 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
George Necula
c6a60054b9 [shape_poly] linalg.schur: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543533821
2023-06-26 13:59:01 -07:00
George Necula
a91412e1e7 [shape_poly] linalg.triangular_solve: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543506845
2023-06-26 12:13:12 -07:00
George Necula
ea0e50f765 [shape_poly] Refactor support for dynamic shapes for linalg.eig and linalg.eigh
The support for dynamic shapes for linalg.eig and linalg.eigh has been added
before we added the helper function `mk_result_types_and_shapes`, which has
been used for all other linalg primitives. Here we refactor linalg.eig and
linalg.eigh support to use these helper functions and follow the same style
as for other linalg primitives.

PiperOrigin-RevId: 543495381
2023-06-26 11:31:31 -07:00
George Necula
aadcec2b1b [shape_poly] Refactor some older support for shape polymorphism for linalg.
Moving some helper functions from linalg.py to hlo_helpers.py, so that we
can reuse them for more custom calls, including those in gpu_solver.

Also renamed some helper functions, e.g., _hlo_s32 -> hlo_s32, and ir_constant_i32 -> hlo_s32.

PiperOrigin-RevId: 543448560
2023-06-26 08:40:22 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
George Necula
bbc6f30693 [shape_poly] linalg.lu: for shape polymorphism for native serialization on CPU.
We support polymorphism only on the batch sizes for now. The
jaxlib and C++ code support full dynamic shapes.

Also added backwards compatibility tests for the LU custom calls
for CPU, and improved the checking of LU results by checking
the invariant for the result as opposed to checking goldens.

PiperOrigin-RevId: 542852925
2023-06-23 07:25:24 -07:00
George Necula
f02d122366 [shape_poly] linalg.cholesky: add support for shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 542497269
2023-06-22 02:21:24 -07:00
George Necula
d940db609d [shape_poly] linalg.qr: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 542489216
2023-06-22 01:40:08 -07:00
George Necula
92288d3071 [shape_poly] linalg.svd: shape polymorphism for native serialization on CPU
PiperOrigin-RevId: 542483203
2023-06-22 01:05:59 -07:00
George Necula
3adfe321b0 [shape_poly] linalg.eig: shape polymorphism with native serialization on CPU
The backwards compatibility tests to be added separately.

PiperOrigin-RevId: 541122069
2023-06-16 23:59:18 -07:00
George Necula
a2ac510dc3 [shape_poly] Add support for dynamic shapes for eigh
We can only handle dynamic sizes for the batch dimensions for now.

PiperOrigin-RevId: 529001830
2023-05-02 23:27:59 -07:00
Eugene Burmako
f337c00ed5 Remove *_mhlo compatibility shims from jaxlib
We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted.

PiperOrigin-RevId: 510820007
2023-02-19 09:03:14 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
2246887f7b Add input-output aliasing annotations for LAPACK calls on CPU.
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Anish Tondwalkar
aeabfc4b37 [mhlo] ConstOp -> ConstantOp, a la HLO
This begins the process of aligning names within MHLO with names in HLO, so we
can mechanically ensure that they remain in sync.

PiperOrigin-RevId: 454984812
2022-06-14 16:33:24 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
909c0328b0 Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead.

This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation.

Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.

Fixes https://github.com/google/jax/issues/2322

PiperOrigin-RevId: 449033350
2022-05-16 12:59:57 -07:00
Peter Hawkins
883cf2b1e9 Refactor custom call building code in jaxlib to use a helper function.
Refactoring only, no functional changes intended.

This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.

Fixes https://github.com/google/jax/issues/10474

PiperOrigin-RevId: 447076192
2022-05-06 14:51:24 -07:00
Matthew Johnson
0c5864a220 add xla_client._version checks for mhlo.ConstOp signature
fix break from 0cf08d0c6841332240cae873e4b4cf9a9b313373
2022-05-04 09:54:06 -07:00
jax authors
0cf08d0c68 Integrate LLVM at llvm/llvm-project@46cc04de34
Updates LLVM usage to match
[46cc04de341b](https://github.com/llvm/llvm-project/commit/46cc04de341b)

PiperOrigin-RevId: 446430294
2022-05-04 05:31:41 -07:00
Jake VanderPlas
c6343ddf8e jax.scipy.linalg.schur: error on 16-bit floats
Fixes https://github.com/google/jax/issues/10530

PiperOrigin-RevId: 446279906
2022-05-03 13:47:44 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
6c1461b52b [MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
PiperOrigin-RevId: 441769591
2022-04-14 08:38:21 -07:00
Peter Hawkins
bc658e7456 [MHLO] Add direct MHLO lowerings for most linear algebra kernels.
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Aden Grue
8884ce5b98 Migrate 'jaxlib' CPU custom-calls to the status-returning API
PiperOrigin-RevId: 438165260
2022-03-29 17:14:14 -07:00
Leello Tadesse Dadi
f9a246ac19 schur lapack wrapper 2021-09-29 14:29:52 +02:00
Peter Hawkins
94f97b920f Refactor JAX CPU kernels to make them usable from C++.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.

This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.

When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.

Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.

PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00