This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.
Changes:
1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal type that's already a vector space itself.
2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
3. Add `to_tangent_type` calls in various other places where they're missing.
4. Remove non-support for float0 in custom derivatives? (See the float0 sketch after this list.)
5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
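For context on item 4, here is a minimal sketch of the float0 tangent-type convention (the function `f` is made up for illustration): an integer primal has tangent type `float0`, so the tangent passed to `jax.jvp` must be a zero-sized `float0` array rather than an integer.
```
import numpy as np
import jax

def f(x, n):
    return x * n  # hypothetical function mixing a float and an int input

# The tangent type of the int primal `n` is float0 (a zero-sized dtype),
# so its tangent must be a float0 zero rather than an integer.
primal_out, tangent_out = jax.jvp(
    f, (2.0, 3), (1.0, np.zeros((), dtype=jax.dtypes.float0)))
```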
PiperOrigin-RevId: 676115753
Two eigenvalue decomposition methods: GEEV (general eigenvalue solver) for non-symmetric matrices, and SYEVD/HEEVD for symmetric or Hermitian matrices.
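These solvers ultimately back the familiar `jnp.linalg` entry points; a minimal usage sketch (the matrix values are arbitrary):
```
import jax.numpy as jnp

# Non-symmetric matrix: handled by the GEEV-style path.
a = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
w, v = jnp.linalg.eig(a)

# Symmetric (or Hermitian) matrix: handled by the SYEVD/HEEVD-style path.
h = jnp.array([[2.0, 1.0], [1.0, 2.0]])
wh, vh = jnp.linalg.eigh(h)
```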
PiperOrigin-RevId: 668381949
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.
One of the biggest changes here is to move all the lowering logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py), since the lowering code is now identical for CPU and GPU (the only difference is the handler names).
PiperOrigin-RevId: 665829252
The lowering logic for all jaxlib custom calls is currently split between JAX and jaxlib for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX instead of jaxlib, since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what this would look like.
Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers that are included there, including `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism.
Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that!
Another note: this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility.
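For illustration, a sketch of registering a lowering built entirely in JAX via this helper; the primitive and call-target name here are hypothetical, and only `ffi_lowering` itself comes from the message above:
```
from jax import core
from jax.extend import ffi
from jax.interpreters import mlir

# Hypothetical primitive and call-target name, for illustration only.
my_prim = core.Primitive("my_prim")

# ffi_lowering builds the entire lowering rule on the JAX side, so jaxlib
# no longer needs to carry per-kernel lowering code.
mlir.register_lowering(my_prim, ffi.ffi_lowering("my_custom_kernel"))
```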
PiperOrigin-RevId: 664680250
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.
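The `fori_loop` behavior relied on here, sketched with a made-up loop body:
```
from jax import lax

def sum_first_n(n):
    # With a concrete trip count this can lower to a scan; if `n` is
    # traced (dynamic), fori_loop lowers to a while_loop instead.
    return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)
```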
PiperOrigin-RevId: 662024940
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, an input `permutation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyway.
In anticipation of supporting shape polymorphism, I added dimension checks to the kernel to match the ones in the abstract eval.
PiperOrigin-RevId: 660831000
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.
To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
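A sketch of the convention, using a hypothetical flattening helper (not the actual internal API):
```
from jaxlib.mlir import ir

def flatten_ir_values(values):
    # Each JAX value is either a bare ir.Value (the common singleton case)
    # or a tuple of zero or more ir.Values.
    flat = []
    for v in values:
        if isinstance(v, ir.Value):
            flat.append(v)   # unwrapped singleton
        else:
            flat.extend(v)   # general case: tuple of ir.Values
    return flat
```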
This PR is a follow-up to #18881.
The changes were generated by adding `from __future__ import annotations` to the files which did not already have it, and running `pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py`.
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.
Since most users of custom kernels probably want to build a custom call, we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.
This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).
The next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.
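A sketch of what an external lowering rule might look like on top of such a helper; the primitive, call-target name, and exact keyword arguments are assumptions for illustration:
```
from jax.interpreters import mlir

def _my_prim_lowering(ctx, x):
    aval_out, = ctx.avals_out
    # Build a stablehlo.custom_call through the JAX helper rather than the
    # raw MLIR bindings; "my_custom_kernel" is a hypothetical call target.
    op = mlir.custom_call(
        "my_custom_kernel",
        result_types=list(mlir.aval_to_ir_types(aval_out)),
        operands=[x],
        backend_config="",
    )
    return op.results
```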
PiperOrigin-RevId: 561042402
Support for dynamic shapes for linalg.eig and linalg.eigh was added
before we introduced the helper function `mk_result_types_and_shapes`,
which has since been used for all the other linalg primitives. Here we
refactor the linalg.eig and linalg.eigh support to use these helper
functions and follow the same style as the other linalg primitives.
PiperOrigin-RevId: 543495381
Previously, we used the following pattern to generate the 1D
tensors representing dynamic shapes:
```
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, shape))
```
Now we write:
```
mlir.eval_dynamic_shape_as_tensor(ctx, shape)
```
We support polymorphism only on the batch sizes for now. The
jaxlib and C++ code support full dynamic shapes.
Also added backwards-compatibility tests for the LU custom calls
for CPU, and improved the checking of LU results by verifying the
decomposition invariant rather than comparing against goldens.
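A sketch of the invariant-based check (the matrix values are arbitrary):
```
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
p, l, u = lu(a)
# Verify the LU invariant p @ l @ u == a instead of comparing to goldens.
assert jnp.allclose(p @ l @ u, a, atol=1e-6)
```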
PiperOrigin-RevId: 542852925