Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.
PiperOrigin-RevId: 668381949
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.
One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).
PiperOrigin-RevId: 665829252
Nanobind prints a warning if a reference to a nanobind-bound function is held at process atexit() time. But there's no particularly good reason we need to hold a function reference that long anyway in this case.
PiperOrigin-RevId: 561795469
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
The support for dynamic shapes for linalg.eig and linalg.eigh has been added
before we added the helper function `mk_result_types_and_shapes`, which has
been used for all other linalg primitives. Here we refactor linalg.eig and
linalg.eigh support to use these helper functions and follow the same style
as for other linalg primitives.
PiperOrigin-RevId: 543495381
Moving some helper functions from linalg.py to hlo_helpers.py, so that we
can reuse them for more custom calls, including those in gpu_solver.
Also renamed some helper functions, e.g., _hlo_s32 -> hlo_s32, and ir_constant_i32 -> hlo_s32.
PiperOrigin-RevId: 543448560
We support polymorphism only on the batch sizes for now. The
jaxlib and C++ code support full dynamic shapes.
Also added backwards compatibility tests for the LU custom calls
for CPU, and improved the checking of LU results by checking
the invariant for the result as opposed to checking goldens.
PiperOrigin-RevId: 542852925
We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted.
PiperOrigin-RevId: 510820007
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
2) Fallback lowerings now produce StableHLO ops as well.
3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).
From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.
PiperOrigin-RevId: 497978733
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
This begins the process of aligning names within MHLO with names in HLO, so we
can mechanically ensure that they remain in sync.
PiperOrigin-RevId: 454984812
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead.
This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation.
Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.
Fixes https://github.com/google/jax/issues/2322
PiperOrigin-RevId: 449033350
Refactoring only, no functional changes intended.
This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.
Fixes https://github.com/google/jax/issues/10474
PiperOrigin-RevId: 447076192
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.
This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.
When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.
Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.
PiperOrigin-RevId: 394705605