Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.
The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.
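For illustration, the split for a hypothetical extension `xyz` might look roughly like the sketch below; the names `XyzDescriptor`, `XyzKernel`, `xyz_kernel`, and `build_xyz_descriptor` are invented for the example and are not actual jaxlib symbols.

```cpp
// xyz_kernels.h -- the C++ parts: no Python or pybind11 dependency.
#include <cstdint>

struct XyzDescriptor {
  std::int64_t batch;
  std::int64_t n;
};

// Custom-call-style entry point; the descriptor reaches it at run time
// (e.g. as one of the inputs or as opaque bytes).
void XyzKernel(void* out, void** ins);

// xyz.cc -- Python bindings around the C++ parts.
#include <pybind11/pybind11.h>

namespace py = pybind11;

PYBIND11_MODULE(_xyz, m) {
  // Expose the kernel's address (wrapped in a capsule) so Python code can
  // register it as a custom-call target, and a helper that packs the
  // descriptor into the bytes the kernel expects.
  m.def("xyz_kernel",
        [] { return py::capsule(reinterpret_cast<void*>(&XyzKernel)); });
  m.def("build_xyz_descriptor", [](std::int64_t batch, std::int64_t n) {
    XyzDescriptor d{batch, n};
    return py::bytes(reinterpret_cast<const char*>(&d), sizeof(d));
  });
}
```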
There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom-call registry.
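A gpu_kernels.cc-style registration might look roughly like the following sketch, which uses XLA's custom-call target registry; the kernel name, its signature, and the header path are assumptions for the example rather than the exact jaxlib contents.

```cpp
// gpu_kernels.cc (sketch): register a GPU custom-call implementation with
// XLA's C++ custom-call target registry so that C++-only clients can run
// computations that refer to it.
#include <cstddef>

#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

namespace jax {

// Defined in the corresponding *_kernels.cc file; follows the GPU
// custom-call calling convention (stream, buffers, opaque descriptor).
void ThreeFry2x32(void* stream, void** buffers, const char* opaque,
                  std::size_t opaque_len);

// Makes the kernel visible to XLA under the name "ThreeFry2x32" for the
// CUDA platform.
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_PLATFORM(ThreeFry2x32, "CUDA");

}  // namespace jax
```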
PiperOrigin-RevId: 394460343
When running across 8xV100 GPUs we observed the following error:
libc++abi: terminating with uncaught exception of type std::runtime_error: third_party/py/jax/jaxlib/cusolver.cc:171: operation cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, static_cast<float*>(workspace), d.lwork, info) failed: cuSolver execution failed
I cannot find documentation to this effect, but I believe it is unsafe to share cuSolver handles across streams, since keeping the handle pool stream-local does fix the issue.
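A minimal sketch of the stream-local idea, with the class and method names invented for the example: handles are created per stream and bound to that stream once, so no handle is ever used on two streams.

```cpp
#include <map>
#include <mutex>
#include <stdexcept>

#include <cuda_runtime_api.h>
#include <cusolverDn.h>

// Hands out cuSolver handles keyed by the stream they belong to, so that a
// handle is never shared across streams. A minimal sketch; real code would
// also manage handle destruction and reuse.
class StreamLocalCusolverHandlePool {
 public:
  cusolverDnHandle_t GetOrCreate(cudaStream_t stream) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = handles_.find(stream);
    if (it != handles_.end()) return it->second;
    cusolverDnHandle_t handle;
    if (cusolverDnCreate(&handle) != CUSOLVER_STATUS_SUCCESS) {
      throw std::runtime_error("cusolverDnCreate failed");
    }
    // Bind the handle to its stream once; it is only ever used there.
    cusolverDnSetStream(handle, stream);
    handles_[stream] = handle;
    return handle;
  }

 private:
  std::mutex mu_;
  std::map<cudaStream_t, cusolverDnHandle_t> handles_;
};
```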
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.
* Add support for complex batched Cholesky decomposition on both platforms.
* Fix a concurrency bug in the batched cuBLAS kernels where a host-to-device memcpy could run before the device buffer was ready (see the sketch below).
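A simplified sketch of the GPU side, assuming cuSolver's batched potrf API; error checking, workspace/descriptor handling, and the complex variants are elided, and the function and parameter names are invented for the example.

```cpp
#include <cstddef>
#include <vector>

#include <cuda_runtime_api.h>
#include <cusolverDn.h>

// Batched GPU Cholesky (potrf) over a [batch, n, n] device buffer.
void BatchedCholesky(cusolverDnHandle_t handle, cudaStream_t stream,
                     float* a,            // [batch, n, n] device buffer
                     float** a_ptrs_dev,  // scratch for batch device pointers
                     int* info_dev,       // [batch] device buffer
                     int batch, int n) {
  // Build the per-matrix pointer array on the host.
  std::vector<float*> a_ptrs_host(batch);
  for (int i = 0; i < batch; ++i) {
    a_ptrs_host[i] = a + static_cast<std::size_t>(i) * n * n;
  }
  // Enqueue the host-to-device copy on the same stream as the solver call,
  // so it is ordered after whatever work produces the device buffers;
  // stream-ordering the copy avoids the kind of race described above.
  cudaMemcpyAsync(a_ptrs_dev, a_ptrs_host.data(), batch * sizeof(float*),
                  cudaMemcpyHostToDevice, stream);
  cusolverDnSetStream(handle, stream);
  cusolverDnSpotrfBatched(handle, CUBLAS_FILL_MODE_LOWER, n, a_ptrs_dev, n,
                          info_dev, batch);
}
```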
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice, XLA inlines aggressively, which causes GPU compilation times to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time.
When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
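For reference, the ThreeFry-2x32 block function itself is short; a simplified CUDA sketch of the algorithm (not the actual jaxlib kernel) might look like this.

```cpp
#include <stdint.h>

__device__ __forceinline__ uint32_t rotl32(uint32_t x, uint32_t d) {
  return (x << d) | (x >> (32 - d));
}

// One ThreeFry-2x32 block: 20 rounds, with a key injection every 4 rounds.
__device__ void threefry2x32(uint32_t k0, uint32_t k1, uint32_t c0,
                             uint32_t c1, uint32_t* out0, uint32_t* out1) {
  const uint32_t rot[8] = {13, 15, 26, 6, 17, 29, 16, 24};
  const uint32_t ks[3] = {k0, k1, k0 ^ k1 ^ 0x1BD11BDAu};
  uint32_t x0 = c0 + ks[0];
  uint32_t x1 = c1 + ks[1];
  for (int block = 0; block < 5; ++block) {
    const uint32_t* r = &rot[(block % 2) * 4];
    for (int i = 0; i < 4; ++i) {
      x0 += x1;
      x1 = rotl32(x1, r[i]);
      x1 ^= x0;
    }
    // Key injection: rotate through the three key words and add the
    // injection counter, as the ThreeFry specification requires.
    x0 += ks[(block + 1) % 3];
    x1 += ks[(block + 2) % 3] + static_cast<uint32_t>(block + 1);
  }
  *out0 = x0;
  *out1 = x1;
}

// One thread per counter pair; a production kernel would typically use a
// grid-stride loop and handle odd-length outputs.
__global__ void ThreeFry2x32Kernel(uint32_t k0, uint32_t k1,
                                   const uint32_t* c0, const uint32_t* c1,
                                   uint32_t* o0, uint32_t* o1, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    threefry2x32(k0, k1, c0[idx], c1[idx], &o0[idx], &o1[idx]);
  }
}
```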