It is possible that a null pointer is inserted into the cache and not updated with a valid kernel call
in case there is an error later during initialization. This change updates the cache to store either
an error or a valid kernel call.
PiperOrigin-RevId: 666161091
The biggest change here is that we now ignore the `info` parameter that is returned by `getrf`. In the previous implementation, we would return an error in the batched implementation, or set the relevant matrix entries to NaN in the non-batched version if `info != 0`. But, since info is only used for shape checking (see LAPACK, cuBLAS and cuSolver docs), I argue that we will never see `info != 0`, because we're including all the shape checks in the kernel already.
PiperOrigin-RevId: 665307128
Previously we always had two steps when extracting the batch size: (1) check the buffer has enough dimensions, (2) get the shape. And, in a few cases, this first check was missing. Now these steps are combined into one function that returns a StatusOr.
As part of this, I needed to fix our implementation of the `ASSIGN_OR_RETURN` macro to properly handle parentheses.
PiperOrigin-RevId: 664803225
In the batched LU decomposition in cuBLAS, the output buffer is required to be a pointer of pointers to the appropriate batch matrices. Previously this reshaping was done on the host and then copied to the device, requiring a synchronization, but it seems straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high priority change, but this seemed like a reasonable time to fix a longstanding TODO.
PiperOrigin-RevId: 663686539
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.
PiperOrigin-RevId: 662945116
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.
In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.
PiperOrigin-RevId: 660831000
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into new target.
PiperOrigin-RevId: 658497960
The backend support for the new custom call was added on June 28th.
Also add backwards compatibility test for the new custom call.
PiperOrigin-RevId: 658011228
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helper header.
PiperOrigin-RevId: 651166306
This change does a few things (arguably too many):
1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.
2. A second related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need to use these macros instead of the `XLA_FFI_DEFINE_HANDLER` version which produces a lambda, so that when XLA checks the address of the handler during registration it is consistent. Without this change, the downstream tests would continue to fail.
3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it was what I set out to do so that's why I'm suggesting them all together, but I can split this in two if that would be preferred.
PiperOrigin-RevId: 651107659
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```
It seems that with the ffi registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is not possible anymore to
register a call target twice.
The fix here is to rollback the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.
PiperOrigin-RevId: 647993991
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.
For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.
PiperOrigin-RevId: 647657084
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend, instead of passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.
PiperOrigin-RevId: 647343656
The typed FFI
* allows passing custom call attributes directly to backend_config= instead
of serializing them into a C++ struct.
* It also handles validation and deserialization of custom call operands.
PiperOrigin-RevId: 630067005
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target
The moved code is unchanged - I just move it from one
file to another and fix up includes and dependencies.
Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.
PiperOrigin-RevId: 623285804
A bug in CUDA prevents us from calling gpuStreamGetCtx inside graph capture. We use cuCtxGetCurrent as workaround for now.
PiperOrigin-RevId: 605417225
Autotuning is not compatible with graph capture because it requires synchronizing.
We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)
PiperOrigin-RevId: 602436960
The autotuner runs a series of benchmarks to determine the best configuration
for a Triton kernel. However, if it encounters a config that does not fit in
shared memory it throws an error and stops. I this eventuality it should just
continue.
PiperOrigin-RevId: 600730507
If the gesvdj() is preferable to gesvd() absent a batch dimension, even if there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().
PiperOrigin-RevId: 582279549