181 Commits

Author SHA1 Message Date
erwin coumans
f0e55e3ce2 move #endif so that Windows doesn't have GetXCR0EAX defined twice
(erroring out)
2021-06-04 17:26:27 -07:00
Peter Hawkins
7a099bf7ee Mark PyInit_cpu_feature_guard symbol as exported.
Otherwise it may get stripped out by -fvisibility=hidden.

PiperOrigin-RevId: 374739406
2021-05-19 15:12:00 -07:00
Peter Hawkins
b03466a390 Remove jax/BUILD file.
This Bazel build file is unused; we only use Bazel to build jaxlib, which does not include any files from jax/.

PiperOrigin-RevId: 374709682
2021-05-19 12:55:34 -07:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
Skye Wanderman-Milne
63dbb99a66 Update README, etc. for jaxlib 0.1.67 release 2021-05-17 17:48:46 -07:00
Qiao Zhang
528d5bbb11 Update README etc for jaxlib 0.1.66 release. 2021-05-11 16:49:32 -07:00
Peter Hawkins
01d6e32c7f Add version constraints to flatbuffers versions.
Require 1.12 or newer, because we've only tested 1.12 and 2.0.
Require less than 3.0, because flatbuffers uses semantic versioning and version 3.0 would mean an incompatible change has been made.
2021-05-10 20:58:22 -04:00
Qiao Zhang
1509f995ed Update flatbuffer ver2 EndVector usage.
PiperOrigin-RevId: 373039656
2021-05-10 16:45:19 -07:00
Qiao Zhang
fa6211af8d Gate on CUSPARSE_VERSION instead of CUDART_VERSION in jaxlib/cusparse.cc.
PiperOrigin-RevId: 373032029
2021-05-10 16:04:31 -07:00
Peter Hawkins
c983d3c660 Bundle libdevice.10.bc with jaxlib wheels.
libdevice.10.bc is a redistributable part of the CUDA SDK.

This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
2021-04-29 10:26:03 -04:00
Jake VanderPlas
122fbcbd09 cusparse: use cstdint types
PiperOrigin-RevId: 369739828
2021-04-21 14:48:31 -07:00
Jake VanderPlas
e152d6645d Fix cusparse headers
PiperOrigin-RevId: 369287633
2021-04-19 13:17:03 -07:00
Jake VanderPlas
919b11e81a Remove unnecessary dependency
PiperOrigin-RevId: 368882451
2021-04-16 11:10:33 -07:00
Jake VanderPlas
0d4bcde7ca Add experimental/sparse_ops & cusparse wrappers in jaxlib
PiperOrigin-RevId: 368663407
2021-04-15 10:11:10 -07:00
Skye Wanderman-Milne
f8f373466c Update README, etc. for jaxlib 0.1.65 release 2021-04-07 17:51:20 -07:00
jax authors
f51fa64ba5 Merge pull request #6337 from ahoenselaar:changelist/366161141
PiperOrigin-RevId: 367130827
2021-04-06 18:57:54 -07:00
Sergei Lebedev
225ffc30d8 Re-exported tensorflow...xla_extension type stubs in jaxlib
The type stubs allow using precise types for XLA primitives instead
of aliasing them to Any.

This commit does not change any type annotations within JAX. That will
be done in a followup. I have manually verified that type stubs are
discoverable by mypy once the new jaxlib is installed by type "checking"

    from jaxlib import xla_extension as xe
    d: xe._Dtype
2021-04-06 14:51:45 +01:00
Andreas Hoenselaar
a19098d462 Reimplement as JAX Primitive 2021-04-03 14:11:36 -07:00
Jake VanderPlas
f9a4162551 Specify minimum jaxlib version in a single location 2021-03-22 16:14:41 -07:00
Skye Wanderman-Milne
0cbe2c1c05 Update README, etc. for jaxlib 0.1.64 release 2021-03-18 16:11:40 -07:00
Skye Wanderman-Milne
757247b791 Update README, etc. for jaxlib 0.1.63 release 2021-03-17 10:14:52 -07:00
Skye Wanderman-Milne
f06bb9a7f4 Update jaxlib version etc. 2021-03-09 17:55:40 -08:00
Skye Wanderman-Milne
7a67b974ac jaxlib version bump etc. 2021-02-12 09:42:04 -08:00
Peter Hawkins
13f3819054 Update README.md for jaxlib 0.1.60.
Bump jaxlib version to 0.1.61 and update changelog.

Change jaxlib numpy version limit to >=1.16 for next release. Releases older than 1.16 are deprecated per NEP 00029. Reenable NumPy 1.20.

Bump minimum jaxlib version to 0.1.60.
2021-02-03 20:44:01 -05:00
Andreas Hoenselaar
2f6ed3cfea Initialize variables in LAPACK work size queries to prevent false positives in memory sanitizers.
Related prior art in SciPy: https://github.com/scipy/scipy/pull/9054
2021-02-01 16:38:24 -08:00
George Necula
a145e3d414 Pin numpy to max version 1.19, to avoid errors with 1.20
Will fix the numpy errors separately.
2021-01-31 15:18:54 +02:00
Peter Hawkins
9bdc2ecc66 Consolidate build macros into a single jax.bzl file.
PiperOrigin-RevId: 352871429
2021-01-20 14:06:22 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
jax authors
8420ee200f Merge pull request #5472 from hawkinsp:build
PiperOrigin-RevId: 352797345
2021-01-20 08:21:02 -08:00
jax authors
d58563845b Merge pull request #5470 from inailuig:fix-rocm-build
PiperOrigin-RevId: 352783431
2021-01-20 06:46:44 -08:00
Peter Hawkins
3bec4b331d Fix ABSL build dependencies of //jaxlib:handle_pool 2021-01-20 09:38:55 -05:00
Clemens Giuliani
57405b0dc1 fix building on ROCm 2021-01-20 10:50:04 +01:00
Skye Wanderman-Milne
7c2454e969 Update jaxlib version, minimum jaxlib version, readme, and changelog.
Bumping the min jaxlib version to support https://github.com/google/jax/pull/5213.
2021-01-15 12:56:08 -08:00
Skye Wanderman-Milne
62b864a654 Update WORKSPACE and jaxlib version to 0.1.59 2021-01-14 13:38:17 -08:00
Clemens Giuliani
82b3bacd5b fix header guard typo 2021-01-13 15:55:14 +01:00
Clemens Giuliani
ce84bdb8fa fix svd on ROCm 2021-01-13 15:54:21 +01:00
Qiao Zhang
a91fcac6cd Fix header guard typo. 2020-12-23 11:25:23 -08:00
jax authors
a8518769a2 Merge pull request #5115 from inailuig:rocm-gpukernels
PiperOrigin-RevId: 348077827
2020-12-17 14:01:04 -08:00
Clemens Giuliani
4981c53ac1 Add BLAS and LAPACK gpu kernels for ROCm 2020-12-16 16:00:17 +01:00
Clemens Giuliani
c128bdd90c extract the shared handle pool code from cublas and cusolver 2020-12-16 16:00:16 +01:00
Peter Hawkins
21914db261 Export LICENSE.txt file in jaxlib wheels. 2020-12-11 10:20:03 -05:00
Peter Hawkins
c06ead6b04 Change jaxlib build rules to build a wheel, rather than writing output to the source directory. 2020-11-20 11:47:00 -05:00
jax authors
9aa04443e9 Merge pull request #4965 from hawkinsp:build
PiperOrigin-RevId: 343333501
2020-11-19 11:32:00 -08:00
Peter Hawkins
4bb5dca779 Fix build.py to work on Linux once again.
* strip DOS end-of-line characters from build.py for consistency with the rest of the source tree.
* use shutil.copy() instead of shutil.copyfile(). On Unix systems we must preserve execute permissions.
* add code to explicitly delete and recreate the target directory.
* Move build/jaxlib/__init_py to jaxlib/__init__.py and have the script move it into position, so the output directory for the jaxlib is an empty directory that the script creates.
2020-11-19 13:14:34 -05:00
Skye Wanderman-Milne
36799f5007 README etc. updates for new jaxlib release 2020-11-12 13:48:44 -08:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Skye Wanderman-Milne
a082169642 Update README and jaxlib.__version__ for new jaxlib release 2020-10-20 11:46:48 -07:00
Alexey Radul
537427aae2 Copybara import of the project:
--
138105a9ea44e7a8c3ce575a4e51b7ed51518d41 by Skye Wanderman-Milne <skyewm@google.com>:

Update README, CHANGELOG, and jaxlib.__version__ for new jaxlib release

PiperOrigin-RevId: 338063494
2020-10-20 08:25:23 -07:00
jax authors
60141d959f Merge pull request #4580 from skye:jaxlib
PiperOrigin-RevId: 337950285
2020-10-19 15:45:56 -07:00
Peter Hawkins
7f4e115a6a [XLA:Python] Validate shapes in Python bindings to avoid crashes.
[JAX] Perform LAPACK workspace calculations in int64 to avoid overflows, clamp the values passed to lapack to int32.

Will fix https://github.com/google/jax/issues/4358 when incorporated into a jaxlib.

PiperOrigin-RevId: 337367394
2020-10-15 13:10:04 -07:00