We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling.
PiperOrigin-RevId: 676982957
Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel, and while it does support lowering without access to a GPU (no more descriptor!), it only supports dynamics shapes in the batch dimensions. There are two main technical challenges:
1. The main `gesvd` kernels in cuSolver/hipSolver only support matrices with shape `(m, n)` with `m >= n`. This means that we need to transpose the inputs and outputs as part of the lowering rule when `m < n`. (Note: we actually just use C layouts instead of Fortran layouts to implement this case.) While this could be handled in the kernel, this seemed like a lot of work for somewhat limited benefit, and it would probably have performance implications.
2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that we need logic in the lowering rule to handle transposition and slicing. This makes it hard to have the algorithm selection be a parameter to the kernel.
Another note: cuSolver has a 64-bit implementation of the SVD, and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCM support, and I have tested it manually. This was a feature request from https://github.com/jax-ml/jax/issues/23413.
PiperOrigin-RevId: 676839182
We have already had most of the relevant pieces and we only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
In https://github.com/google/jax/issues/23687, it was reported that recent jaxlib changes introduced issues when building from source using gcc, instead of the clang build that we test. I'm not 100% sure why the previous macro didn't work, but in investigating I found a version that seems to work on both clang and gcc with simpler logic.
PiperOrigin-RevId: 675641259
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives.
PiperOrigin-RevId: 673832309
Of note, I moved the logic about which algorithm to use, and when to use the batched algorithm into the kernel in order to support shape polymorphism and export.
PiperOrigin-RevId: 671853879
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 671283487
If element bitwidth changes, the ratio of bitwidth is multiplied to the 2nd minormost dim size and the leading dim in tiling. For example, we can bitcast Memref<8x128xf32> with tiling (8, 128) to Memref<16x128xi16> with tiling (16, 128).
PiperOrigin-RevId: 668619683
With proper CAPI in place these dependencies are no longer needed, llvm support needed for string ostream for string APIs.
PiperOrigin-RevId: 668476145
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.
PiperOrigin-RevId: 668381949
The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call.
PiperOrigin-RevId: 666241215
It is possible that a null pointer is inserted into the cache and not updated with a valid kernel call
in case there is an error later during initialization. This change updates the cache to store either
an error or a valid kernel call.
PiperOrigin-RevId: 666161091
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.
One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).
PiperOrigin-RevId: 665829252