There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.
PiperOrigin-RevId: 668381949
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.
One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).
PiperOrigin-RevId: 665829252
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
The lowering logic for all jaxlib custom calls are currently split between JAX and jaxlib for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX instead of jaxlib since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what this would look like.
Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers that are included there, including `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism.
Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that!
Another note, this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility.
PiperOrigin-RevId: 664680250
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.
PiperOrigin-RevId: 662024940
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.
In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.
PiperOrigin-RevId: 660831000
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.
Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
The numerical accuracy test is perfect against the reference implementation, and somewhat loose against the alt grad implementation used for testing.
PiperOrigin-RevId: 654381378
There are 2 reasons for doing this:
* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs.
* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.
Also fixes: https://github.com/google/jax/issues/17422
PiperOrigin-RevId: 650621659