547 Commits

Author SHA1 Message Date
brett koonce
9c5009efd5 tweak cuda/rocm targets
Closes #7955.
2021-09-18 14:29:13 -05:00
yashkatariya
d0acd9f343 Add flags to configure the cuda_compute_capability and rocm_amd_targets 2021-09-17 08:43:25 -07:00
Peter Hawkins
94f97b920f Refactor JAX CPU kernels to make them usable from C++.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.

This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.

When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.

Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.

PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00
Peter Hawkins
f004bcb7b8 [JAX] Refactor JAX custom kernels to split kernel implementations from Python bindings.
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.

The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.

There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.

PiperOrigin-RevId: 394460343
2021-09-02 07:53:09 -07:00
Yash Katariya
a08a9ad42f Delete the generate_release_indexes file.
PiperOrigin-RevId: 394081682
2021-08-31 14:07:07 -07:00
Jake VanderPlas
a5b6a4e6a9 CI: remove flake8 from test requirements. 2021-08-25 11:07:09 -07:00
dependabot[bot]
9f2863c66b Copybara import of the project:
--
57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f by dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>:

Bump flatbuffers from 1.12 to 2.0

Bumps [flatbuffers](https://github.com/google/flatbuffers) from 1.12 to 2.0.
- [Release notes](https://github.com/google/flatbuffers/releases)
- [Commits](https://github.com/google/flatbuffers/compare/v1.12.0...v2.0.0)

---
updated-dependencies:
- dependency-name: flatbuffers
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7686 from google:dependabot/pip/flatbuffers-2.0 57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f
PiperOrigin-RevId: 392097862
2021-08-20 17:13:26 -07:00
Jake VanderPlas
cbcd6eeadb CI: bump mypy & flake8 versions to newest 2021-08-20 14:35:37 -07:00
Jake VanderPlas
7fa151c5c3 cleanup: remove redundant entry from test-requirements 2021-08-20 10:09:14 -07:00
Yash Katariya
7dca05066a Remove build --strategy=Genrule=standalone since this makes genrule run locally instead of on remote RBE clusters leading to errors that are very hard to debug.
Add this back in build.py because that script is used for building jaxlib locally.

PiperOrigin-RevId: 391603481
2021-08-18 13:56:43 -07:00
Reza Rahimi
f454f6b7b8 fix rocm_amdgpu_targets for rocm 2021-08-17 22:13:29 +00:00
yashkatariya
3a99d18df8 Add nocuda jaxlib wheels to the index as well 2021-08-13 09:03:08 -07:00
Yash Katariya
2d1854a8ba Create a .bazelrc file that is the base for all the builds. The current build.py workflow will not be affected since this .bazelrc file will be overridden. I am going to change that workflow in the coming CLs.
PiperOrigin-RevId: 390003558
2021-08-10 16:17:58 -07:00
jax authors
1a2ddc055d Merge pull request #7419 from cloudhan:win-builder-prepare
PiperOrigin-RevId: 389938513
2021-08-10 11:42:48 -07:00
yashkatariya
b24574f743 Fix the regex 2021-08-05 13:25:06 -07:00
yashkatariya
86aaf80dce Use bazel --version 2021-08-05 09:45:39 -07:00
Roy Frostig
6984f30d5e Merge pull request #7443 from superbobry:jaxlib-xla-extension
PiperOrigin-RevId: 388235730
2021-08-02 12:40:18 -07:00
Sergei Lebedev
2a994bdb02 Type stubs for jaxlib.xla_extension no longer use -stubs suffix
PEP-561 does not specify whether subpackages of a non-stub-only-package
could use the -stubs suffix. setuptools seems to allow that, yet mypy fails
to resolve the subpackage with a -stubs suffix.

This commit makes jaxlib.xla_extension a ~normal package with a toplevel
__init__.pyi.
2021-08-02 14:31:11 +01:00
Peter Hawkins
cde2612893 Update minimum NumPy version in jaxlib build scripts. 2021-08-02 09:22:32 -04:00
Cloud Han
be1705306c fix bazel auto download 2021-07-31 19:13:16 +08:00
Cloud Han
d4349a42b5 workaround weird issue due to non-standard cuda path 2021-07-30 19:25:33 +08:00
Cloud Han
86e37d996a some fs support now can be disabled without compile error 2021-07-30 19:23:18 +08:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
Peter Hawkins
6c08702489 Add support for ppc64le cross-compilation on Linux.
Use Bazel 4.1.0 unconditionally on all platforms.
2021-07-23 10:39:02 -04:00
Cloud Han
2d321c26e6 Use TF_CUDA_PATHS
CUDA_TOOLKIT_PATH and CUDNN_INSTALL_PATH are deprecated, see TF 2.0
release notes for more information
2021-07-18 22:55:34 +08:00
Cloud Han
4fa79ce1cb fix machine tag, on windows platforms.machine() returns AMD64 instread of x64_64 2021-07-18 22:55:02 +08:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
jax authors
6aa20d8f8f Merge pull request #7294 from hawkinsp:py36
PiperOrigin-RevId: 384994957
2021-07-15 13:19:23 -07:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Peter Hawkins
f5c61a892a Add support for cross-compiling jaxlib for Mac ARM. 2021-07-15 10:37:53 -04:00
Peter Hawkins
7d2aec105f Add an option to disable NCCL. 2021-07-13 09:10:29 -04:00
jax authors
10569871b7 Merge pull request #7260 from skye:release_indexes
PiperOrigin-RevId: 384341784
2021-07-12 16:27:47 -07:00
Skye Wanderman-Milne
1a650d2e50 Update generate_release_index[es].py to also produce libtpu_releases.html.
Previously, the libtpu-nightly wheels were included in the same index
file as the jaxlib wheels (jax_releases.html). This caused issues
because it would cause `pip install jax[tpu] -f jaxlib_releases.html`
to install a cuda jaxlib, instead of the regular CPU/TPU jaxlib from
pypi.

Instead, we create a separate index file for the libtpu-nightly
wheels, so `pip install jax[tpu] -f libtpu_releases.html` still uses
the jaxlib from pypi.

This also renames generate_release_index.py to generate_release_indexes.py.
2021-07-12 16:02:54 -07:00
Peter Hawkins
0de4a60834 Update pillow pin to >= 8.3.1.
8.3.1 fixed the issue from https://github.com/google/jax/pull/7166.
2021-07-07 08:33:29 -04:00
Jake VanderPlas
4ba343aa83 CI: pin pillow dependency to 8.2 to avoid failures under 8.3 2021-07-01 16:32:35 -07:00
Skye Wanderman-Milne
55276d15e4 Fix pip install jax[tpu]
* Updates jax_releases.html index to include libtpu wheels
* Change [tpu] extras to specify `libtpu-nightly` instead of wheel URL

The full install command will now be:
`pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html`
(similar to the cuda install commands)

I've already pushed an updated jax_releases.html to the jax-releases GCS bucket.
2021-06-23 14:13:15 -07:00
Qiao Zhang
7c54b44dba Bump numpy versions for macos build. 2021-06-23 10:23:43 -07:00
jax authors
6ae2f6d5a4 Merge pull request #7060 from zhangqiaorjc:npv
PiperOrigin-RevId: 380918414
2021-06-22 16:37:50 -07:00
Qiao Zhang
f3c8a22d66 Bump np versions for build script. 2021-06-22 16:29:56 -07:00
Qiao Zhang
18c7610e96 Catch CalledProcessError in build script. 2021-06-22 13:36:14 -07:00
Jake VanderPlas
0c91be7b46 CI: temporarily pin numpy to <1.21 2021-06-22 11:15:16 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid the change on external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Peter Hawkins
07277f0785 Bump mypy version to 0.902. 2021-06-14 10:05:34 -04:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
akbir khan
dc81610bb6 updated to official bazel.4.1.0 2021-05-24 21:40:08 +01:00
Peter Hawkins
dacf31f202 Check for NumPy and SciPy versions during jaxlib builds. 2021-05-24 12:39:37 -04:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
Peter Hawkins
86887a21e2 Add Mac M1 support to build.py. 2021-05-18 21:53:31 -04:00
Peter Hawkins
40c5e376d8 Pin flatbuffers 1.12 for CI tests. 2021-05-10 18:21:25 -04:00
jax authors
010c383ab3 Merge pull request #6694 from hawkinsp:wheel2
PiperOrigin-RevId: 372963990
2021-05-10 10:49:17 -07:00