14 Commits

Author SHA1 Message Date
Peter Hawkins
f004bcb7b8 [JAX] Refactor JAX custom kernels to split kernel implementations from Python bindings.
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.

The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.

There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.
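The registration pattern can be sketched as follows. This is a minimal, self-contained illustration of a name-to-function custom-call registry in the spirit of what gpu_kernels.cc does; all names here are illustrative and are not XLA's actual registry API.

```cpp
// Minimal sketch of a custom-call registry; names are illustrative,
// not XLA's actual CustomCallTargetRegistry API.
#include <map>
#include <string>

using CustomCallTarget = void (*)(void* out, const void** in);

class KernelRegistry {
 public:
  static KernelRegistry& Global() {
    static KernelRegistry* registry = new KernelRegistry;
    return *registry;
  }
  void Register(const std::string& name, CustomCallTarget fn) {
    targets_[name] = fn;
  }
  CustomCallTarget Lookup(const std::string& name) const {
    auto it = targets_.find(name);
    return it == targets_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::string, CustomCallTarget> targets_;
};

// A toy kernel: adds two floats. Real jaxlib kernels follow the XLA
// custom-call calling convention and live in xyz_kernels.cc.
void AddF32(void* out, const void** in) {
  *static_cast<float*>(out) =
      *static_cast<const float*>(in[0]) + *static_cast<const float*>(in[1]);
}

// Static registration at load time, analogous to what gpu_kernels.cc
// does for each GPU kernel.
struct Registrar {
  Registrar() { KernelRegistry::Global().Register("add_f32", AddF32); }
} add_f32_registrar;
```

With this split, a C++ consumer only needs the registry and the `*_kernels` translation units; no Python is involved.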

PiperOrigin-RevId: 394460343
2021-09-02 07:53:09 -07:00
Qiao Zhang
a93eaf3c9e Use absl::Status::message() instead of error_message().
PiperOrigin-RevId: 389810033
2021-08-09 23:44:36 -07:00
Yash Katariya
6f4937c33d In OSS #include "third_party/tensorflow/..." should be #include "tensorflow/..."
PiperOrigin-RevId: 389788858
2021-08-09 20:46:10 -07:00
Aden Grue
c368969955 Use the new "custom call status" facility to report errors in jaxlib
PiperOrigin-RevId: 389734200
2021-08-09 15:06:39 -07:00
Tom Hennigan
afbd831ec3 Avoid sharing handles across streams.
When running across 8xV100 GPUs we observed the following error:

    libc++abi: terminating with uncaught exception of type std::runtime_error: third_party/py/jax/jaxlib/cusolver.cc:171: operation cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, static_cast<float*>(workspace), d.lwork, info) failed: cuSolver execution failed

I cannot find documentation to this effect, but I believe it is unsafe to share cuSolver handles across streams; keeping the handle pool stream-local does resolve the issue.
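The stream-local pool pattern can be sketched as follows, with placeholder types standing in for cudaStream_t and cusolverDnHandle_t. The class and member names are illustrative, not jaxlib's actual handle-pool code.

```cpp
// Sketch of a stream-local handle pool. Stream and Handle are
// placeholders for cudaStream_t and cusolverDnHandle_t.
#include <map>
#include <mutex>
#include <vector>

using Stream = int;
using Handle = int;

class HandlePool {
 public:
  // Borrow a handle bound to `stream`. Handles are never shared across
  // streams; they are only reused on the stream they were created for.
  Handle Borrow(Stream stream) {
    std::lock_guard<std::mutex> lock(mu_);
    auto& free_list = pools_[stream];
    if (!free_list.empty()) {
      Handle h = free_list.back();
      free_list.pop_back();
      return h;
    }
    // Real code would call cusolverDnCreate + cusolverDnSetStream here.
    return next_handle_++;
  }

  // Return a handle to its stream's free list for later reuse.
  void Return(Stream stream, Handle h) {
    std::lock_guard<std::mutex> lock(mu_);
    pools_[stream].push_back(h);
  }

 private:
  std::mutex mu_;
  std::map<Stream, std::vector<Handle>> pools_;
  Handle next_handle_ = 0;
};
```

Keying the free lists by stream is what prevents a handle created for one stream from ever being handed out on another.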
2021-07-16 11:11:21 +00:00
Tom Hennigan
d6e56f2df9 Add source location and expression to error messages for CUDA API calls.
Before:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: operation not supported

After:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported
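The before/after above comes from wrapping each CUDA call in a macro that captures `__FILE__`, `__LINE__`, and the stringized expression. A minimal self-contained sketch follows; the macro name, Status type, and FailingCall helper are illustrative, not jaxlib's.

```cpp
// Sketch of attaching source location and expression text to errors.
// Status stands in for cudaError_t; names are illustrative.
#include <sstream>
#include <stdexcept>
#include <string>

using Status = int;
constexpr Status kSuccess = 0;

// Stand-in for cudaGetErrorString.
std::string ErrorString(Status) { return "operation not supported"; }

// A call that always fails, for demonstration only.
Status FailingCall() { return 1; }

// Wrap a CUDA API call; on failure, throw with file:line, the exact
// expression text, and the driver's error string.
#define THROW_IF_ERROR(expr)                                              \
  do {                                                                    \
    Status _s = (expr);                                                   \
    if (_s != kSuccess) {                                                 \
      std::ostringstream _oss;                                            \
      _oss << __FILE__ << ":" << __LINE__ << ": CUDA operation " << #expr \
           << " failed: " << ErrorString(_s);                             \
      throw std::runtime_error(_oss.str());                               \
    }                                                                     \
  } while (0)
```

The stringized `#expr` is what puts the literal call text (e.g. `cudaMallocAsync(&buffer, bufferSize, stream)`) into the message.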
2021-07-15 15:42:46 +00:00
Clemens Giuliani
4981c53ac1 Add BLAS and LAPACK gpu kernels for ROCm 2020-12-16 16:00:17 +01:00
Clemens Giuliani
c128bdd90c extract the shared handle pool code from cublas and cusolver 2020-12-16 16:00:16 +01:00
Peter Hawkins
ffa198e8ef
Fix test failure on TPU. (#2088)
Update GUARDED_BY annotations to use newer ABSL_GUARDED_BY form.
2020-01-27 12:48:10 -05:00
Peter Hawkins
c5a9eba3a8
Implement batched cholesky decomposition using LAPACK/Cusolver (#1956)
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.

Adds support for complex batched Cholesky decomposition on both platforms.
Fixes a concurrency bug in the batched cuBLAS kernels where a host-to-device memcpy could occur before the device buffer was ready.
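For illustration, here is a self-contained unblocked Cholesky applied across a contiguous batch. The actual change dispatches to LAPACK's potrf on CPU and cuSolver's batched potrf on GPU, so everything below is a sketch of the math, not the jaxlib code.

```cpp
// Sketch of batched Cholesky; real code calls LAPACK spotrf /
// cuSolver's batched potrf. Row-major, lower-triangular factor.
#include <cmath>
#include <cstddef>
#include <vector>

// Factor one n x n SPD matrix in place so that A = L * L^T, zeroing
// the strict upper triangle. Returns false if A is not positive definite.
bool CholeskyInPlace(double* a, std::size_t n) {
  for (std::size_t j = 0; j < n; ++j) {
    double d = a[j * n + j];
    for (std::size_t k = 0; k < j; ++k) d -= a[j * n + k] * a[j * n + k];
    if (d <= 0.0) return false;
    a[j * n + j] = std::sqrt(d);
    for (std::size_t i = j + 1; i < n; ++i) {
      double s = a[i * n + j];
      for (std::size_t k = 0; k < j; ++k) s -= a[i * n + k] * a[j * n + k];
      a[i * n + j] = s / a[j * n + j];
    }
    for (std::size_t k = j + 1; k < n; ++k) a[j * n + k] = 0.0;
  }
  return true;
}

// Apply the factorization to each matrix in a contiguous batch,
// laid out as batch_size consecutive n x n matrices.
bool BatchedCholesky(std::vector<double>& batch, std::size_t batch_size,
                     std::size_t n) {
  for (std::size_t b = 0; b < batch_size; ++b) {
    if (!CholeskyInPlace(batch.data() + b * n * n, n)) return false;
  }
  return true;
}
```

The batched layout (one buffer, fixed stride between matrices) mirrors how the GPU kernels receive their operands.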
2020-01-07 10:56:15 -05:00
Skye Wanderman-Milne
7a154f71bc
Fix jaxlib build by not exposing nvcc to pybind11. (#1819) 2019-12-05 18:59:29 -08:00
Peter Hawkins
34dfbc8ae6
Add error checking to PRNG CUDA kernel. (#1760)
Refactor error checking code into a common helper library.
2019-11-25 11:48:45 -05:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice, XLA inlines aggressively, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
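The hash itself is small. Below is a host-side reference implementation of ThreeFry2x32 (20 rounds, with the rotation constants and key-schedule parity constant from the Random123 specification); the hand-written kernel applies the same function elementwise on the GPU. Function and variable names are illustrative.

```cpp
// Host-side reference ThreeFry2x32 (20 rounds), per the Random123 spec.
#include <cstdint>
#include <utility>

static inline std::uint32_t RotL(std::uint32_t x, int r) {
  return (x << r) | (x >> (32 - r));
}

std::pair<std::uint32_t, std::uint32_t> ThreeFry2x32(std::uint32_t key0,
                                                     std::uint32_t key1,
                                                     std::uint32_t ctr0,
                                                     std::uint32_t ctr1) {
  // Rotation constants alternate between these two groups of four.
  constexpr int kRot[2][4] = {{13, 15, 26, 6}, {17, 29, 16, 24}};
  // Key schedule: the third word is the parity constant XOR both keys.
  std::uint32_t ks[3] = {key0, key1, 0x1BD11BDAu ^ key0 ^ key1};
  std::uint32_t x0 = ctr0 + ks[0];
  std::uint32_t x1 = ctr1 + ks[1];
  // Five blocks of four mix rounds, with a key injection after each block.
  for (int block = 0; block < 5; ++block) {
    const int* r = kRot[block % 2];
    for (int i = 0; i < 4; ++i) {
      x0 += x1;
      x1 = RotL(x1, r[i]) ^ x0;
    }
    x0 += ks[(block + 1) % 3];
    x1 += ks[(block + 2) % 3] + static_cast<std::uint32_t>(block + 1);
  }
  return {x0, x1};
}
```

The CUDA kernel is essentially this loop, fully unrolled, run once per output element, which is what makes it cheap to compile compared to letting XLA inline and recompile the hash at every use site.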
2019-11-24 13:06:23 -05:00
Peter Hawkins
5ac356d680 Add support for batched triangular solve and LU decomposition on GPU using cuBlas. 2019-08-08 13:34:53 -04:00