114 Commits

Author SHA1 Message Date
Peter Hawkins
3c4527b6b0 Check build and wheel are installed before building jaxlib. 2023-07-26 11:46:11 -07:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
Peter Hawkins
1d4b10b775 Remove --distinct_host_configuration from Bazel flags.
This flag does nothing under Bazel 6 and will be removed in Bazel 7.
2023-07-11 11:38:05 -04:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Peter Hawkins
119661ce6b Remove older plugin device integration.
Users of this mechanism should migrate to the newer PJRT plugin registration mechanism (see the comments on discover_plugins() in this file).
2023-06-14 15:26:58 -04:00
Peter Hawkins
a18e82b28b Update bazel version to 6.1.2.
Several of our CI builds are already using 6.1.2, so it's probably best to upgrade for consistency.
2023-05-10 10:57:29 -04:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
jax authors
6715736583 Merge pull request #15205 from yhtang:editable-jaxlib-build
PiperOrigin-RevId: 519704474
2023-03-27 06:33:31 -07:00
Yu-Hang 'Maxin' Tang
caaa0a2669 add build option to create editable jaxlib
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
2023-03-24 21:25:26 +00:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Jake VanderPlas
e7f53479e2 Some cleanups related to dropping Python 3.7 2022-11-29 15:54:49 -08:00
jax authors
254dc24a8b Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake
21f82c6c0d Use the pjrt plugin device client. 2022-08-17 14:34:07 +01:00
Peter Hawkins
03876bd702 build.py fixes.
* Add aarch64 as a known target_cpu value.
* Only pass --bazel_options to build actions since they can make "bazel
  shutdown" fail.
* Pass the bazel startup options to "bazel shutdown".

Issue https://github.com/google/jax/issues/7097
Fixes https://github.com/google/jax/issues/7639
2022-08-16 15:47:15 +00:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
22304eeb2e Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
9fb9e12169 Don't include PTX for older GPU generations.
See: https://github.com/tensorflow/tensorflow/pull/55613

For a CUDA build at head with the default compute capabilities, reduces wheel size from 141MB to 112MB.

Don't redundantly specify default compute capabilities in .bazelrc and in
build.py.
2022-05-02 20:27:37 +00:00
Yash Katariya
6ba9fb699d Upgrade the bazel version to 5.1.1
PiperOrigin-RevId: 441338363
2022-04-12 17:48:09 -07:00
Yash Katariya
aa5d6b4a58 Fix the breakage by including --experimental_cc_shared_library as done by TF.
PiperOrigin-RevId: 438746867
2022-03-31 23:07:42 -07:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Peter Hawkins
2e0cfe8e42 Update the list of default CUDA capabilities used for wheel builds to match build.py. 2022-02-15 09:23:28 -05:00
Peter Hawkins
2388e353da Increase bazel version to 5.0.0 to match TensorFlow
(8871926b0a).
2022-01-28 21:11:02 +00:00
Peter Hawkins
04369a3588 Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00
Peter Hawkins
66823d1392 Include compute capability 8.0 SASS in jaxlib wheels.
Drop compute capability 6.1 to avoid growing the wheel size.

Also fix an unrelated build error due to a gcc warning in boringssl.
2021-12-14 14:27:19 -05:00
Peter Hawkins
ffb7ec1651 Update Bazel to 4.2.1.
Fixes #8573
2021-12-02 09:11:38 -05:00
Peter Hawkins
9212d5c83b Print the bazel version from build.py.
Increment the minimum version that build.py checks for to 3.7.2.
2021-11-03 12:12:54 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
brett koonce
9c5009efd5 tweak cuda/rocm targets
Closes #7955.
2021-09-18 14:29:13 -05:00
yashkatariya
d0acd9f343 Add flags to configure the cuda_compute_capability and rocm_amd_targets 2021-09-17 08:43:25 -07:00
Peter Hawkins
94f97b920f Refactor JAX CPU kernels to make them usable from C++.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.

This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.

When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.

Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.

PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00
Yash Katariya
7dca05066a Remove build --strategy=Genrule=standalone since this makes genrule run locally instead of on remote RBE clusters leading to errors that are very hard to debug.
Add this back in build.py because that script is used for building jaxlib locally.

PiperOrigin-RevId: 391603481
2021-08-18 13:56:43 -07:00
Reza Rahimi
f454f6b7b8 fix rocm_amdgpu_targets for rocm 2021-08-17 22:13:29 +00:00
Yash Katariya
2d1854a8ba Create a .bazelrc file that is the base for all the builds. The current build.py workflow will not be affected since this .bazelrc file will be overridden. I am going to change that workflow in the coming CLs.
PiperOrigin-RevId: 390003558
2021-08-10 16:17:58 -07:00
jax authors
1a2ddc055d Merge pull request #7419 from cloudhan:win-builder-prepare
PiperOrigin-RevId: 389938513
2021-08-10 11:42:48 -07:00
yashkatariya
b24574f743 Fix the regex 2021-08-05 13:25:06 -07:00
yashkatariya
86aaf80dce Use bazel --version 2021-08-05 09:45:39 -07:00
Cloud Han
be1705306c fix bazel auto download 2021-07-31 19:13:16 +08:00
Cloud Han
d4349a42b5 workaround weird issue due to non-standard cuda path 2021-07-30 19:25:33 +08:00
Cloud Han
86e37d996a some fs support now can be disabled without compile error 2021-07-30 19:23:18 +08:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
Peter Hawkins
6c08702489 Add support for ppc64le cross-compilation on Linux.
Use Bazel 4.1.0 unconditionally on all platforms.
2021-07-23 10:39:02 -04:00
Cloud Han
2d321c26e6 Use TF_CUDA_PATHS
CUDA_TOOLKIT_PATH and CUDNN_INSTALL_PATH are deprecated, see TF 2.0
release notes for more information
2021-07-18 22:55:34 +08:00
jax authors
6aa20d8f8f Merge pull request #7294 from hawkinsp:py36
PiperOrigin-RevId: 384994957
2021-07-15 13:19:23 -07:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00