This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.
PiperOrigin-RevId: 662945116
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into new target.
PiperOrigin-RevId: 658497960
This change does a few things (arguably too many):
1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.
2. A second related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need to use these macros instead of the `XLA_FFI_DEFINE_HANDLER` version which produces a lambda, so that when XLA checks the address of the handler during registration it is consistent. Without this change, the downstream tests would continue to fail.
3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it was what I set out to do so that's why I'm suggesting them all together, but I can split this in two if that would be preferred.
PiperOrigin-RevId: 651107659
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend, instead of passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.
PiperOrigin-RevId: 647343656
The typed FFI
* allows passing custom call attributes directly to backend_config= instead
of serializing them into a C++ struct.
* It also handles validation and deserialization of custom call operands.
PiperOrigin-RevId: 630067005
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.
PiperOrigin-RevId: 559898790
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.
PiperOrigin-RevId: 482031830
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.
PiperOrigin-RevId: 479670565
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
PiperOrigin-RevId: 479642543