Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.
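As a rough illustration of the templating win (the wrapper below is a hypothetical sketch, not jaxlib's actual code), a single generic driver plus a one-line specialization per dtype replaces the per-dtype functions Cython required:

    // Hypothetical sketch: a template specialization selects the LAPACK
    // symbol for each element type, so one generic driver covers them all.
    extern "C" {
    void sgetrf_(const int* m, const int* n, float* a, const int* lda,
                 int* ipiv, int* info);
    void dgetrf_(const int* m, const int* n, double* a, const int* lda,
                 int* ipiv, int* info);
    }

    template <typename T>
    struct GetrfFn;
    template <>
    struct GetrfFn<float> { static constexpr auto* fn = sgetrf_; };
    template <>
    struct GetrfFn<double> { static constexpr auto* fn = dgetrf_; };

    // The same code path handles float and double; supporting more types is
    // a matter of adding more one-line specializations.
    template <typename T>
    void Getrf(int m, int n, T* a, int lda, int* ipiv, int* info) {
      GetrfFn<T>::fn(&m, &n, a, &lda, ipiv, info);
    }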
This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on SciPy, which we only needed for Cython cimport reasons.
When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.
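For illustration, the weak-symbol approach looks roughly like this on toolchains that support __attribute__((weak)) (the routine and helper names are examples, not jaxlib's actual declarations):

    // Declare the LAPACK routine as a weak symbol so the binary still links
    // even when no LAPACK implementation is provided.
    extern "C" void sgetrf_(const int* m, const int* n, float* a,
                            const int* lda, int* ipiv,
                            int* info) __attribute__((weak));

    // A weak symbol that nothing defined resolves to nullptr, so we can
    // detect at runtime whether the user linked a LAPACK library or whether
    // we need to fall back to fetching the kernels from Python.
    inline bool HaveLapackGetrf() { return &sgetrf_ != nullptr; }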
Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.
PiperOrigin-RevId: 394705605
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.
The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.
There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.
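Registering a kernel with that registry looks roughly like the sketch below; the header path, kernel name, and call signature are assumptions about XLA's C++ custom-call registry rather than an excerpt from gpu_kernels.cc:

    #include <cstddef>
    #include <cuda.h>
    #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

    // Hypothetical GPU kernel with the XLA GPU custom-call signature: a CUDA
    // stream, the operand/result buffers, and an opaque descriptor blob.
    void MyGpuKernel(CUstream stream, void** buffers, const char* opaque,
                     size_t opaque_len) {
      // ... launch work on `stream` ...
    }

    // Register at static-initialization time so that any XLA computation run
    // from C++ can resolve the "my_gpu_kernel" custom-call target without
    // importing a Python module.
    static bool registered = [] {
      xla::CustomCallTargetRegistry::Global()->Register(
          "my_gpu_kernel", reinterpret_cast<void*>(&MyGpuKernel), "CUDA");
      return true;
    }();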
PiperOrigin-RevId: 394460343
PEP 561 does not specify whether subpackages of a non-stub-only package
can use the -stubs suffix. setuptools seems to allow it, yet mypy fails
to resolve a subpackage with a -stubs suffix.
This commit makes jaxlib.xla_extension a (mostly) normal package with a
top-level __init__.pyi.
When running across 8xV100 GPUs we observed the following error:
libc++abi: terminating with uncaught exception of type std::runtime_error: third_party/py/jax/jaxlib/cusolver.cc:171: operation cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, static_cast<float*>(workspace), d.lwork, info) failed: cuSolver execution failed
I cannot find documentation to this effect, but I believe it is unsafe to share cuSolver handles across streams, since keeping the handle pool stream-local does solve the issue.
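A minimal sketch of the stream-local idea (class and method names are illustrative; jaxlib's actual pool has more machinery):

    #include <map>
    #include <mutex>
    #include <cuda_runtime_api.h>
    #include <cusolverDn.h>

    // Cache one cuSolver handle per CUDA stream instead of sharing a handle
    // across streams.
    class StreamLocalHandlePool {
     public:
      cusolverDnHandle_t GetForStream(cudaStream_t stream) {
        std::lock_guard<std::mutex> lock(mu_);
        auto it = handles_.find(stream);
        if (it == handles_.end()) {
          cusolverDnHandle_t handle;
          cusolverDnCreate(&handle);
          // Bind the handle to the stream it will be used with.
          cusolverDnSetStream(handle, stream);
          it = handles_.emplace(stream, handle).first;
        }
        return it->second;
      }

     private:
      std::mutex mu_;
      std::map<cudaStream_t, cusolverDnHandle_t> handles_;
    };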
A user reported that with their Quadro M4000 GPU (Driver: 460.56) tridiagonal_solve was throwing an "unsupported operation" error. I improved the logging (also included in this patch) and tracked it down to:
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported
I had some trouble figuring out when async malloc is supported (it appears to fail on cards with compute capability < 6), but found an alternative approach: compute the buffer size ahead of time and ask XLA to allocate it. This is preferable anyway, although it requires passing null pointers into cusparseSgtsv2_bufferSizeExt, which seems to work today but might change in future cuSPARSE releases.
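For illustration, the ahead-of-time size query described above looks roughly like this (the helper name is made up and error handling is omitted):

    #include <cusparse.h>

    // Query the gtsv2 workspace size without any device data: cuSPARSE only
    // needs the problem shape, so the data pointers can be null. The returned
    // size is then requested from XLA as an ordinary buffer instead of being
    // allocated with cudaMallocAsync at run time.
    size_t Gtsv2BufferSize(cusparseHandle_t handle, int m, int n, int ldb) {
      size_t buffer_size = 0;
      cusparseSgtsv2_bufferSizeExt(handle, m, n, /*dl=*/nullptr, /*d=*/nullptr,
                                   /*du=*/nullptr, /*B=*/nullptr, ldb,
                                   &buffer_size);
      return buffer_size;
    }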
To make sure that the CPU feature guard happens first, before any other code that may use instructions the host CPU does not support, use a separate C extension module.
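A rough sketch of what such a guard can check (illustrative only, using the GCC/Clang __builtin_cpu_supports builtin on x86; the actual module and feature list depend on the build):

    #include <cstdio>
    #include <cstdlib>

    // Run this before importing any module compiled with AVX enabled, so the
    // user sees a clear error instead of a SIGILL crash.
    void CheckCpuFeatures() {
      if (!__builtin_cpu_supports("avx")) {
        std::fprintf(stderr,
                     "This version of jaxlib requires a CPU with AVX support.\n");
        std::abort();
      }
    }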
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
Require 1.12 or newer, because we've only tested 1.12 and 2.0.
Require less than 3.0, because flatbuffers uses semantic versioning and version 3.0 would mean an incompatible change has been made.
libdevice.10.bc is a redistributable part of the CUDA SDK.
This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.