Unlike the other more detailed ports, this version doesn't take full advantage of the features provided by the FFI. For example, the kernels could be updated to use the ScratchAllocator instead of querying the workspace size during lowering (see the sketch below). However, since these kernels are really only meant to be experimental, it's not obvious to me that anything more sophisticated is worth the extra work.
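As a concrete illustration of the ScratchAllocator alternative, here's a minimal sketch assuming the `xla::ffi` C++ API; the handler name, buffer layout, and workspace size are all made up for the example:

```c++
#include <optional>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Hypothetical handler: requests its device workspace from XLA at run time
// via ScratchAllocator, so no workspace size needs to be computed during
// lowering.
static ffi::Error MyKernelImpl(ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y,
                               ffi::ScratchAllocator scratch) {
  std::optional<void*> workspace = scratch.Allocate(1 << 20);
  if (!workspace.has_value()) {
    return ffi::Error(ffi::ErrorCode::kResourceExhausted,
                      "failed to allocate workspace");
  }
  // ... launch the kernel using x.untyped_data(), y->untyped_data(), and
  // *workspace ...
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(MyKernel, MyKernelImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::AnyBuffer>()
                                  .Ret<ffi::AnyBuffer>()
                                  .Ctx<ffi::ScratchAllocator>());
```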
PiperOrigin-RevId: 724016331
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK directly on the host, which has good performance for small to moderately sized problems (up to roughly 2048x2048). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the break-even point used by PyTorch, and I see roughly similar behavior so far.)
We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
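For reference, a rough sketch of that loading behavior (illustrative, not jaxlib's actual code; the function name is made up, and only standard `dlfcn.h` calls are used):

```c++
#include <dlfcn.h>

#include <cstdlib>

// Prefer an explicit path from JAX_GPU_MAGMA_PATH; otherwise fall back to
// dlopen-ing the default soname from the usual search path.
void* LoadMagma() {
  const char* path = std::getenv("JAX_GPU_MAGMA_PATH");
  if (path == nullptr) path = "libmagma.so";
  // RTLD_NOW surfaces unresolved symbols at load time rather than at first
  // use.
  return dlopen(path, RTLD_NOW | RTLD_LOCAL);
}
```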
PiperOrigin-RevId: 697631402
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives.
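To give a flavor of the split, here's a hypothetical shim of the kind that would live in the interface target (the helper name is made up; `cusolverDnSgetrf_bufferSize` is the real cuSolver call):

```c++
#include <cusolverDn.h>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Wraps the raw workspace-size query for single-precision LU decomposition
// so that the FFI code in solver_kernels_ffi can deal in absl::Status
// instead of checking cusolverStatus_t by hand.
inline absl::StatusOr<int> GetrfBufferSize(cusolverDnHandle_t handle, int m,
                                           int n, float* a, int lda) {
  int lwork = 0;
  if (cusolverDnSgetrf_bufferSize(handle, m, n, a, lda, &lwork) !=
      CUSOLVER_STATUS_SUCCESS) {
    return absl::InternalError("cusolverDnSgetrf_bufferSize failed");
  }
  return lwork;
}
```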
PiperOrigin-RevId: 673832309
In the batched LU decomposition in cuBLAS, the output buffer is required to be an array of pointers to the individual batch matrices. Previously this array was built on the host and then copied to the device, requiring a synchronization, but it's straightforward to instead implement a tiny CUDA kernel to do this work on device (see the sketch below). This definitely isn't a bottleneck or a high-priority change, but this seemed like a reasonable time to fix a longstanding TODO.
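The kernel is roughly this shape (a sketch, not jaxlib's exact implementation; names are illustrative):

```cuda
// Computes the per-batch pointers directly on device, so no host-side
// staging buffer or host-to-device copy (and hence no synchronization) is
// needed.
__global__ void MakeBatchPointers(char* buffer, void** out, int batch,
                                  size_t bytes_per_matrix) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < batch) {
    out[idx] = buffer + static_cast<size_t>(idx) * bytes_per_matrix;
  }
}
```

Launched on the same stream as the subsequent cuBLAS call, the pointer array is guaranteed to be ready when the batched routine consumes it.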
PiperOrigin-RevId: 663686539
This simplifies the lowering logic, and means that we don't pay a performance penalty when exporting with shape polymorphism.
PiperOrigin-RevId: 662945116
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.
PiperOrigin-RevId: 658497960
This change does a few things (arguably too many):
1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.
2. A second, related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need these macros instead of the `XLA_FFI_DEFINE_HANDLER` version, which produces a lambda, so that the handler has a consistent address when XLA checks it during registration (see the sketch after this list). Without this change, the downstream tests would continue to fail.
3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu` and (I think!) makes the details easier to follow. This last change is less urgent; it's what I originally set out to do, so I'm proposing them all together, but I can split this in two if that would be preferred.
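For point 2, the pattern looks roughly like the following (the handler name and binding here are illustrative): the `DECLARE` macro in the header gives every translation unit the same symbol, and the `DEFINE` macro emits a named function rather than a lambda, so the address XLA sees at registration time is stable.

```c++
// In the header:
XLA_FFI_DECLARE_HANDLER_SYMBOL(LuPivotsToPermutation);

// In the source file, next to the implementation function:
XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::AnyBuffer>()    // pivots
                                  .Ret<ffi::AnyBuffer>());  // permutation
```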
PiperOrigin-RevId: 651107659
Register the callback with the default call target name from C++, enabling Triton calls with the default name to work in C++-only contexts (e.g. serving).
PiperOrigin-RevId: 545211452
Metadata, in particular code location information, is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which raises the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
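The core of the pass is small; here's a minimal sketch using standard MLIR APIs (the pass boilerplate is omitted, and the function name is made up):

```c++
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"

// Replaces every operation's location in the module with UnknownLoc, so
// that the serialized bytes no longer depend on source code locations.
void StripLocations(mlir::ModuleOp module) {
  mlir::MLIRContext* context = module.getContext();
  module->walk(
      [&](mlir::Operation* op) { op->setLoc(mlir::UnknownLoc::get(context)); });
}
```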
PiperOrigin-RevId: 525534901
The code for both CUDA and ROCm is almost identical, so with a small shim library to handle the differences (sketched below) we can share almost everything.
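The shim boils down to a vendor header along these lines (a sketch; the `JAX_GPU_CUDA`/`JAX_GPU_HIP` macro names and the particular aliases shown are illustrative):

```c++
#ifdef JAX_GPU_CUDA
#include <cuda_runtime_api.h>
typedef cudaStream_t gpuStream_t;
typedef cudaError_t gpuError_t;
#define gpuSuccess cudaSuccess
#define gpuStreamSynchronize cudaStreamSynchronize
#else  // JAX_GPU_HIP
#include <hip/hip_runtime_api.h>
typedef hipStream_t gpuStream_t;
typedef hipError_t gpuError_t;
#define gpuSuccess hipSuccess
#define gpuStreamSynchronize hipStreamSynchronize
#endif

// Kernels are then written once against the gpu* vocabulary and compiled
// twice, once per vendor.
```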
PiperOrigin-RevId: 483666051