1371 Commits

Author SHA1 Message Date
Leello Tadesse Dadi
f9a246ac19 schur lapack wrapper 2021-09-29 14:29:52 +02:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
94f97b920f Refactor JAX CPU kernels to make them usable from C++.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.

This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.

When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.

Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.

PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00
Peter Hawkins
f1f0aefc59 Further split up cuda_prng and cuda_lu_pivot kernels to avoid exposing ABSL code to NVCC.
With this change, we are careful not to include any ABSL-including .cc files in cuda_library rules.

PiperOrigin-RevId: 394544751
2021-09-02 14:35:46 -07:00
Peter Hawkins
f004bcb7b8 [JAX] Refactor JAX custom kernels to split kernel implementations from Python bindings.
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.

The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.

There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.

PiperOrigin-RevId: 394460343
2021-09-02 07:53:09 -07:00
yashkatariya
be824a792e Update files after new jaxlib release 0.1.71 2021-09-01 10:43:20 -07:00
Aden Grue
6b00b44807 Move all Abseil dependencies out of jaxlib CUDA libraries
These were breaking the build with CUDA 10.2.

PiperOrigin-RevId: 391875083
2021-08-19 16:55:37 -07:00
Reza Rahimi
16a110e4ff fix custom_call_status for rocm 2021-08-16 04:39:47 +00:00
Jake VanderPlas
f8081a9a52 [sparse] fix GPU translation rule for coo/csr matmat 2021-08-10 10:13:29 -07:00
Qiao Zhang
a93eaf3c9e Use absl::Status::message() instead of error_message().
PiperOrigin-RevId: 389810033
2021-08-09 23:44:36 -07:00
Qiao Zhang
2afba31f71 Fix return value for MakeBatchPointers.
PiperOrigin-RevId: 389795975
2021-08-09 21:44:46 -07:00
Yash Katariya
6f4937c33d In OSS #include "third_party/tensorflow/..." should be #include "tensorflow/..."
PiperOrigin-RevId: 389788858
2021-08-09 20:46:10 -07:00
Yash Katariya
bf967d88d8 Upgrade versions after jaxlib release
PiperOrigin-RevId: 389753047
2021-08-09 16:37:44 -07:00
Aden Grue
c368969955 Use the new "custom call status" facility to report errors in jaxlib
PiperOrigin-RevId: 389734200
2021-08-09 15:06:39 -07:00
Aden Grue
d6df61c305 Fix the move constructor for Handle
PiperOrigin-RevId: 389212536
2021-08-06 10:55:30 -07:00
jax authors
df103f7e66 Merge pull request #7493 from yashk2810:release
PiperOrigin-RevId: 388811438
2021-08-04 16:36:52 -07:00
Yash Katariya
b5b44b0639 Remove cublas_header by default for anything under jaxlib/*
PiperOrigin-RevId: 388802894
2021-08-04 15:53:45 -07:00
yashkatariya
677eed49a3 Add Cusolver dep 2021-08-04 12:13:05 -07:00
yashkatariya
72da78c64c Add cublas headers. cublas_kernels BUILD target has that dependency 2021-08-04 12:03:28 -07:00
Peter Hawkins
2c3b647939 Add missing <map> #include to jaxlib.
PiperOrigin-RevId: 388350144
2021-08-02 18:27:14 -07:00
Sergei Lebedev
2a994bdb02 Type stubs for jaxlib.xla_extension no longer use -stubs suffix
PEP-561 does not specify whether subpackages of a non-stub-only-package
could use the -stubs suffix. setuptools seems to allow that, yet mypy fails
to resolve the subpackage with a -stubs suffix.

This commit makes jaxlib.xla_extension a ~normal package with a toplevel
__init__.pyi.
2021-08-02 14:31:11 +01:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
jax authors
29187a3317 Merge pull request #7315 from ROCmSoftwarePlatform:fix_pr_7306_rocm
PiperOrigin-RevId: 385566677
2021-07-19 09:01:05 -07:00
Cloud Han
6d84e02724 Work around a compilation issue on Windows when the CUDA version is < 11.0 2021-07-18 23:01:08 +08:00
Reza Rahimi
ee08acd046 update rocblas because of PR-7306 2021-07-17 08:04:56 +00:00
jax authors
7bd5fe54fd Merge pull request #7307 from cloudhan:missing_lib
PiperOrigin-RevId: 385170297
2021-07-16 10:08:32 -07:00
jax authors
1044401e50 Merge pull request #7306 from tomhennigan:changelist/385115540
PiperOrigin-RevId: 385169180
2021-07-16 10:03:42 -07:00
jax authors
8dc2de552e Merge pull request #7273 from tomhennigan:changelist/384520796
PiperOrigin-RevId: 385126655
2021-07-16 05:20:30 -07:00
Cloud Han
cf7298d238 cusparseGetErrorString is an external symbol; without cusparse_lib as a dependency, linking fails 2021-07-16 19:55:55 +08:00
Tom Hennigan
afbd831ec3 Avoid sharing handles across streams.
When running across 8xV100 GPUs we observed the following error:

    libc++abi: terminating with uncaught exception of type std::runtime_error: third_party/py/jax/jaxlib/cusolver.cc:171: operation cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, static_cast<float*>(workspace), d.lwork, info) failed: cuSolver execution failed

I cannot find documentation to this effect, but I believe it is unsafe to share cuSolver handles across streams, since keeping the handle pool stream-local resolves the issue.
2021-07-16 11:11:21 +00:00
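The stream-local handle pool described in afbd831ec3 could be sketched roughly as follows. This is an illustrative Python model only, not the C++ code in jaxlib: the `HandlePool` class, its `borrow`/`give_back` methods, and the integer "handles" are all hypothetical names chosen for the example.

```python
import itertools
from collections import defaultdict


class HandlePool:
    """Hypothetical pool that keeps idle handles separately for each stream."""

    def __init__(self, create_handle):
        self._create_handle = create_handle
        # Maps a stream to the list of idle handles previously used on it.
        self._idle = defaultdict(list)

    def borrow(self, stream):
        # Reuse a handle only if it was last used on this same stream;
        # otherwise create a fresh one.
        idle = self._idle[stream]
        return idle.pop() if idle else self._create_handle()

    def give_back(self, stream, handle):
        # Return the handle to the pool of the stream it was used on.
        self._idle[stream].append(handle)


# Stand-in for handle creation: each new "handle" is just the next integer.
counter = itertools.count()
pool = HandlePool(lambda: next(counter))

first = pool.borrow("stream0")
pool.give_back("stream0", first)
other = pool.borrow("stream1")   # a different stream never sees stream0's handle
again = pool.borrow("stream0")   # the same stream reuses its own handle
```

The point of keying the pool by stream is that a handle returned by one stream is never handed to another, which matches the behaviour the commit message reports as fixing the failure.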
jax authors
6aa20d8f8f Merge pull request #7294 from hawkinsp:py36
PiperOrigin-RevId: 384994957
2021-07-15 13:19:23 -07:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Tom Hennigan
afa0d5725b Compute gtsv2 buffer size ahead of time and pass in to kernel.
A user reported that with their Quadro M4000 GPU (Driver: 460.56) tridiagonal_solve was throwing an "unsupported operation" error. I improved the logging (also included in this patch) and tracked it down to:

jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported

I had trouble working out when async malloc is supported (it seems to fail on cards with compute capability < 6), but found an alternative approach: we compute the buffer size ahead of time and ask XLA to allocate the buffer. This is preferable, although it requires passing null pointers into cusparseSgtsv2_bufferSizeExt, which seems to work today but might change in future cuSPARSE releases.
2021-07-15 16:06:23 +00:00
Tom Hennigan
d6e56f2df9 Add source location and expression to error messages for CUDA API calls.
Before:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: operation not supported

After:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported
2021-07-15 15:42:46 +00:00
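The message format shown in d6e56f2df9 (source location, then the failing expression, then the underlying error) can be imitated in a few lines. The sketch below is Python rather than the C++ used in jaxlib, and `format_error` and `check` are hypothetical helper names, not part of any JAX API.

```python
import sys


def format_error(file, line, expr, message):
    """Build a message in the 'After' style: location, expression, cause."""
    return f"{file}:{line}: CUDA operation {expr} failed: {message}"


def check(ok, expr, message):
    """Raise RuntimeError naming the caller's file and line (hypothetical helper)."""
    if ok:
        return
    # Look one frame up so the reported location is the call site, not this helper.
    caller = sys._getframe(1)
    raise RuntimeError(
        format_error(caller.f_code.co_filename, caller.f_lineno, expr, message)
    )


example = format_error(
    "third_party/py/jax/jaxlib/cusparse.cc",
    902,
    "cudaMallocAsync(&buffer, bufferSize, stream)",
    "operation not supported",
)
```

Carrying the expression text alongside the location is what makes the "After" message actionable: the reader can see which call failed without attaching a debugger.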
Qiao Zhang
82e74959fe Update changelog for jaxlib-0.1.69. 2021-07-12 12:06:41 -07:00
Qiao Zhang
a22841b6bb Bump jaxlib ver to 0.1.68. 2021-06-23 12:37:56 -07:00
jax authors
28977761d5 Merge pull request #6849 from tomhennigan:changelist/376000598
PiperOrigin-RevId: 381010658
2021-06-23 05:46:01 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid the change on external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
erwin coumans
f0e55e3ce2 Move #endif so that GetXCR0EAX is not defined twice on Windows
(which caused a compile error)
2021-06-04 17:26:27 -07:00
Tom Hennigan
ffac40a2c0 Add lax.linalg.tridiagonal_solve(..), lowering to cusparse_gtsv2<T>() on GPU.
Fixes #6830.
2021-06-02 13:49:02 +00:00
Peter Hawkins
7a099bf7ee Mark PyInit_cpu_feature_guard symbol as exported.
Otherwise it may get stripped out by -fvisibility=hidden.

PiperOrigin-RevId: 374739406
2021-05-19 15:12:00 -07:00
Peter Hawkins
b03466a390 Remove jax/BUILD file.
This Bazel build file is unused; we only use Bazel to build jaxlib, which does not include any files from jax/.

PiperOrigin-RevId: 374709682
2021-05-19 12:55:34 -07:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard runs first, before any other code that may use instructions the CPU does not support, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
Skye Wanderman-Milne
63dbb99a66 Update README, etc. for jaxlib 0.1.67 release 2021-05-17 17:48:46 -07:00
Qiao Zhang
528d5bbb11 Update README etc for jaxlib 0.1.66 release. 2021-05-11 16:49:32 -07:00
Peter Hawkins
01d6e32c7f Add version constraints to flatbuffers versions.
Require 1.12 or newer, because we've only tested 1.12 and 2.0.
Require less than 3.0, because flatbuffers uses semantic versioning and version 3.0 would mean an incompatible change has been made.
2021-05-10 20:58:22 -04:00
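The version window described in 01d6e32c7f (at least 1.12, below 3.0) can be illustrated with a small check. This is a sketch only: the requirement string is an assumed spelling of the constraint in pip's specifier syntax, and `flatbuffers_version_ok` is a hypothetical helper, not code from the commit.

```python
# Assumed way of writing the constraint as a pip requirement specifier.
REQUIREMENT = "flatbuffers>=1.12,<3.0"

MIN_VERSION = (1, 12)   # oldest version stated as tested
MAX_VERSION = (3, 0)    # first excluded version (next major release)


def flatbuffers_version_ok(version):
    """Return True if 'major.minor[.patch]' falls in the [1.12, 3.0) window."""
    parts = version.split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    # Tuple comparison orders by major first, then minor.
    return MIN_VERSION <= (major, minor) < MAX_VERSION
```

Comparing `(major, minor)` tuples avoids the common mistake of comparing version strings lexically, where "1.9" would sort after "1.12".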
Qiao Zhang
1509f995ed Update flatbuffer ver2 EndVector usage.
PiperOrigin-RevId: 373039656
2021-05-10 16:45:19 -07:00
Qiao Zhang
fa6211af8d Gate on CUSPARSE_VERSION instead of CUDART_VERSION in jaxlib/cusparse.cc.
PiperOrigin-RevId: 373032029
2021-05-10 16:04:31 -07:00
Peter Hawkins
c983d3c660 Bundle libdevice.10.bc with jaxlib wheels.
libdevice.10.bc is a redistributable part of the CUDA SDK.

This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
2021-04-29 10:26:03 -04:00