42 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
Peter Hawkins
02534bdff1 Move MLIR dependencies onto //jaxlib rule instead of wheel build rule.
Change in preparation for allowing testing with Bazel.

PiperOrigin-RevId: 458460128
2022-07-01 06:54:31 -07:00
Peter Hawkins
1e171ccd10 Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
2022-06-29 12:51:01 -07:00
Robert Suderman
499a4e733c Expose ml_program dialect for MLIR builder
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
2022-06-28 20:29:41 +00:00
Peter Hawkins
22304eeb2e Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
Peter Hawkins
d56601a896 Include mlir.transforms and mlir.passmanager in jaxlib BUILD.
These increase the binary size of jaxlib only a negligible amount, and allow running passes like canonicalization from Python.
2022-05-16 13:00:22 +00:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00
Peter Hawkins
83dee8f81e Make jaxlib extension libraries Bazel deps of //jaxlib.
Previously we depended on various .so files directly so they were pulled into the jaxlib wheel build, but it seems to work to add the libraries in question to //jaxlib and depend on that in the usual way.

It appears if a py_library() is used as a data-dependency of another rule, Bazel includes any transitive C++ extension deps, and that's what we want.

PiperOrigin-RevId: 446802592
2022-05-05 13:30:06 -07:00
Aart Bik
c1261ccd27 Adds a wrapper to sparse tensor dialect, as part of an
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.

PiperOrigin-RevId: 443438043
2022-04-21 11:48:44 -07:00
jax authors
cf9a900d78 Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
jax authors
6c45969fe4 Integrate LLVM at llvm/llvm-project@eb27da7dec
Updates LLVM usage to match
[eb27da7dec67](https://github.com/llvm/llvm-project/commit/eb27da7dec67)

PiperOrigin-RevId: 432199388
2022-03-03 08:24:39 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Cloud Han
317edcdacd fix mlir capi dll building and linking 2021-11-25 00:07:25 +08:00
Peter Hawkins
11f6c535ae Add MLIR:Python bindings to jaxlib build.
PiperOrigin-RevId: 407657331
2021-11-04 13:29:58 -07:00
Yash Katariya
ac0796048f Move cuda .py files to :gpu_support so that if :gpu_support is not present, then internal jaxlib will act like a CPU jaxlib even if --config=cuda is specified.
PiperOrigin-RevId: 403170945
2021-10-14 13:20:01 -07:00
Peter Hawkins
94f97b920f Refactor JAX CPU kernels to make them usable from C++.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.

This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.

When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.

Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.

PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00
Peter Hawkins
f004bcb7b8 [JAX] Refactor JAX custom kernels to split kernel implementations from Python bindings.
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.

The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.

There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.

PiperOrigin-RevId: 394460343
2021-09-02 07:53:09 -07:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
Peter Hawkins
c983d3c660 Bundle libdevice.10.bc with jaxlib wheels.
libdevice.10.bc is a redistributable part of the CUDA SDK.

This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
2021-04-29 10:26:03 -04:00
Jake VanderPlas
0d4bcde7ca Add experimental/sparse_ops & cusparse wrappers in jaxlib
PiperOrigin-RevId: 368663407
2021-04-15 10:11:10 -07:00
Andreas Hoenselaar
a19098d462 Reimplement as JAX Primitive 2021-04-03 14:11:36 -07:00
jax authors
a8518769a2 Merge pull request #5115 from inailuig:rocm-gpukernels
PiperOrigin-RevId: 348077827
2020-12-17 14:01:04 -08:00
Clemens Giuliani
4981c53ac1 Add BLAS and LAPACK gpu kernels for ROCm 2020-12-16 16:00:17 +01:00
Peter Hawkins
21914db261 Export LICENSE.txt file in jaxlib wheels. 2020-12-11 10:20:03 -05:00
Peter Hawkins
c23edb805e Include a LICENSE.txt in jaxlib wheels.
PiperOrigin-RevId: 346988918
2020-12-11 06:45:15 -08:00
Peter Hawkins
c06ead6b04 Change jaxlib build rules to build a wheel, rather than writing output to the source directory. 2020-11-20 11:47:00 -05:00
Cloud Han
a6acce58e0 Build on Windows
1. Build on Windows

2. Fix OverflowError

    When calling `key = random.PRNGKey(0)` OverflowError: Python int too
    large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
    from python int to int32.

3. fix file path in regex of errors_test

4. handle ValueError of os.path.commonpath
2020-11-19 23:33:06 +08:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Peter Hawkins
8e166adcbd
Unbreak jaxlib build. (#4098) 2020-08-18 21:24:41 -04:00
George Necula
c7aff1da06
Revert "Use pytree from xla_client. (#4063)" (#4081)
This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f.

Tryting to revert because it seems that this produces test
failures in Google.
2020-08-17 12:53:18 +03:00
Jean-Baptiste Lespiau
d8de6b6141
Use pytree from xla_client. (#4063) 2020-08-14 11:44:03 -04:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
android
ddbdcfb9c9 Add TPU Driver to jaxlib (#1673) 2019-11-12 18:11:39 -08:00
Peter Hawkins
5ac356d680 Add support for batched triangular solve and LU decomposition on GPU using cuBlas. 2019-08-08 13:34:53 -04:00
Peter Hawkins
72047c6eca Update XLA. 2019-08-07 12:55:09 -04:00
Peter Hawkins
ed3e2308c1 Add support for linear algebra ops on GPU using Cusolver:
* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition.

Make LU decomposition tests less sensitive to the exact decomposition; check that we have a decomposition, not precisely the same one scipy returns.
2019-08-02 11:16:15 -04:00
Peter Hawkins
510a9167c5 Add C++ implementation of pytree logic.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
Peter Hawkins
5eff830f0e Move jaxlib version.py into jaxlib, and install it in build/jaxlib as build action.
Update jaxlib version check to look in jaxlib.version.
2019-04-01 08:21:22 -07:00
Peter Hawkins
0316d31479 Rename build/BUILD to build/BUILD.bazel.
Avoids name conflict when building wheels on case-insensitive filesystems as on Mac OS X.
2019-01-13 09:38:10 -05:00