This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 671283487
The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call.
PiperOrigin-RevId: 666241215
Previously we always had two steps when extracting the batch size: (1) check the buffer has enough dimensions, (2) get the shape. And, in a few cases, this first check was missing. Now these steps are combined into one function that returns a StatusOr.
As part of this, I needed to fix our implementation of the `ASSIGN_OR_RETURN` macro to properly handle parentheses.
PiperOrigin-RevId: 664803225
The SVD kernel implementation used to require workspace shapes to be determined prior to the custom call on the JAX's side. The new FFI kernels need not demand these shapes to be specified anymore. They are evaluated during kernel runtime.
PiperOrigin-RevId: 662413273
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 659492696
There were two helper functions for implementing FFI calls that were included directly alongside jaxlib's CPU kernels that will be useful for the GPU kernels as well. This moves those functions into ffi_helpers so that they are accessible from there too.
PiperOrigin-RevId: 658002501
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.
PiperOrigin-RevId: 657186057
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 655484166
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 651691430
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helper header.
PiperOrigin-RevId: 651166306
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 650212574
When an FFI kernel is executed, there isn't any global try/except block (I think!) so it's probably a good idea to avoid throwing.
Instead, it should be safer to handle mapping failures to ffi::Error manually.
PiperOrigin-RevId: 647348889
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 644845277
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 642954763
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.
PiperOrigin-RevId: 638132572
Starting in June 2023 we have switched the CPU lowering for FFT to use
the new custom call dynamic_ducc_fft. We are now out of the backwards
compatibility window and we remove the old ducc_fft.
We need to keep dynamic_ducc_fft a little bit longer (May 2024).
PiperOrigin-RevId: 627981921
This is a necessary first step before adding DUCC support to XLA,
otherwise the JAX tests in the XLA repo pull from JAX's copy,
which has slightly different build rules.
PiperOrigin-RevId: 576880208
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.
PiperOrigin-RevId: 559898790
NOTE: this version of DUCC has a breaking change, where the fft.h header
no longer contains the definitions of many fft functions - instead they exist
within fft1d_impl.h and fftnd_impl.h.
PiperOrigin-RevId: 554567641
The idea is that we take all the values that can contain dimension sizes
from the descriptor (shape, strides_in, strides_out) and we pass them as
1-d tensor operands. We also pass as an operand the output_shape, so that
we can use the hlo.CustomCallOp `indices_of_output_shapes` attribute to
tell the shape refinement how to compute the shape of the result.
We keep the old descriptor and the ducc_fft registration for the old
C++ custom targets for backwards compatibility (for 6 months). That behavior
is tested by back_compat_test.py.
The one downside of this implementation is that it moves some of the
ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This
was necessary because that code computes with dimensions that are now
dynamic. In JAX we have support for evaluating dynamic shapes and turning
them into 1-d tensors.
Also added backwards compatibility test for dynamic_ducc_fft and kept the
old test for ducc_fft.
PiperOrigin-RevId: 541168692
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
PiperOrigin-RevId: 525534901
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934