Dan Foreman-Mackey afaa3bf43c Port GPU kernels for SVD to the FFI.
Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel. It does support lowering without access to a GPU (no more descriptor!), but it only supports dynamic shapes in the batch dimensions. There are two main technical challenges:

1. The main `gesvd` kernels in cuSOLVER/hipSOLVER only support matrices of shape `(m, n)` with `m >= n`, so the lowering rule has to transpose the inputs and outputs when `m < n`. (Note: in practice we implement this case by using C layouts instead of Fortran layouts.) This could be handled inside the kernel, but that seemed like a lot of work for limited benefit, and it would probably have performance implications.
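A minimal NumPy sketch of the `m < n` case, assuming a solver that only accepts tall matrices (here `np.linalg.svd` stands in for it; `svd_wide_via_transpose` and `c_buffer_as_fortran` are hypothetical names, not JAX internals). If `A^H = U' S V'^H` is the SVD of the tall matrix `A^H`, then `A = V' S U'^H`, so the two factors swap roles and are conjugate-transposed. The second helper shows why switching layouts is enough to get a transpose for free: a row-major `(m, n)` buffer read as column-major with shape `(n, m)` is exactly `A^T`.

```python
import numpy as np


def svd_wide_via_transpose(a, full_matrices=True):
    """Hypothetical helper: SVD of a (possibly batched) (m, n) matrix using
    only a "tall" solve. np.linalg.svd stands in for a kernel that requires
    m >= n."""
    m, n = a.shape[-2:]
    if m >= n:
        return np.linalg.svd(a, full_matrices=full_matrices)
    # a^H has shape (n, m) with n > m, so the tall-only solve applies.
    a_h = np.conj(np.swapaxes(a, -1, -2))
    u_t, s, vh_t = np.linalg.svd(a_h, full_matrices=full_matrices)
    # a^H = u_t @ diag(s) @ vh_t  implies  a = vh_t^H @ diag(s) @ u_t^H
    u = np.conj(np.swapaxes(vh_t, -1, -2))
    vh = np.conj(np.swapaxes(u_t, -1, -2))
    return u, s, vh


def c_buffer_as_fortran(a):
    """Reinterpret the row-major (C) buffer of a 2-D (m, n) array as a
    column-major (Fortran) array of shape (n, m). The result equals a.T."""
    m, n = a.shape
    return np.reshape(a.ravel(order="C"), (n, m), order="F")
```

For complex inputs the same idea works with a plain transpose, with the conjugates moved onto the returned factors.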

2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that the lowering rule needs logic to handle transposition and slicing, which makes it hard to expose the algorithm selection as a parameter to the kernel.
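The post-processing described above can be sketched in NumPy as follows. This assumes a batched `gesvdj`-style routine that returns `(U, S, V)` with full `U` and `V`; `fake_gesvdj_batched` is a hypothetical stand-in built on `np.linalg.svd`, and `postprocess_gesvdj` is an illustrative name, not the actual lowering code.

```python
import numpy as np


def fake_gesvdj_batched(a):
    """Hypothetical stand-in for a batched gesvdj-style kernel: it returns V
    (not V^H) and always computes full U and V, as described for gesvdj."""
    u, s, vh = np.linalg.svd(a, full_matrices=True)
    v = np.conj(np.swapaxes(vh, -1, -2))
    return u, s, v


def postprocess_gesvdj(u, s, v, full_matrices):
    """Convert (U, S, V) to the (U, S, V^H) convention and slice down to the
    reduced shapes when full_matrices=False."""
    vh = np.conj(np.swapaxes(v, -1, -2))
    if not full_matrices:
        k = s.shape[-1]  # k = min(m, n)
        u = u[..., :, :k]
        vh = vh[..., :k, :]
    return u, s, vh
```

Because this extra transpose and slice only applies to one of the algorithms, it lives naturally in the lowering rule rather than behind a single kernel parameter.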

Another note: cuSOLVER has a 64-bit implementation of the SVD, and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCm support, and I have tested it manually. This was a feature request from https://github.com/jax-ml/jax/issues/23413.
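As a rough illustration of what 32-bit sizes can limit (an assumption on my part; the commit does not state why the 64-bit interface is preferred), the snippet below checks whether a single matrix's element count fits in a signed 32-bit integer. A 50,000 x 50,000 matrix has 2.5 billion elements, above the int32 maximum of 2,147,483,647. `exceeds_int32` is a hypothetical helper.

```python
# Largest value representable by a signed 32-bit integer.
INT32_MAX = 2**31 - 1


def exceeds_int32(m, n):
    """Hypothetical check: does an (m, n) matrix have more elements than a
    signed 32-bit integer can represent? Python ints don't overflow, so the
    product is exact."""
    return m * n > INT32_MAX
```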

PiperOrigin-RevId: 676839182
2024-09-20 07:34:50 -07:00