Drop support for NumPy 1.18.

Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
This commit is contained in:
Peter Hawkins 2022-01-04 12:11:00 -08:00 committed by jax authors
parent eaf7885460
commit 04369a3588
6 changed files with 17 additions and 10 deletions

View File

@ -13,6 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
commits](https://github.com/google/jax/compare/jax-v0.2.26...main).
* Breaking changes:
* Support for NumPy 1.18 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The host_callback primitives have been simplified to drop the
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
@ -36,6 +39,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
(e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not
to increase the number of compute capabilities: GPUs with compute capability
6.1 can use the 6.0 SASS.
* Breaking changes
* Support for NumPy 1.18 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
## jaxlib 0.1.75 (Dec 8, 2021)
* New features:

View File

@ -13,9 +13,9 @@ RUN /pyenv/bin/pyenv install 3.8.0
RUN /pyenv/bin/pyenv install 3.9.0
# We pin numpy to the minimum permitted version to avoid compatibility issues.
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.18.5 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.18.5 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.19.5 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.19.5 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.5 setuptools wheel six auditwheel
# Change the CUDA version if it doesn't match the installed version.
ARG JAX_CUDA_VERSION=10.0

View File

@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path):
version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split(".")[:2]))
if numpy_version < (1, 18):
print("ERROR: JAX requires NumPy 1.18 or newer, found " + version + ".")
if numpy_version < (1, 19):
print("ERROR: JAX requires NumPy 1.19 or newer, found " + version + ".")
sys.exit(-1)
return version

View File

@ -44,6 +44,6 @@ build_jax () {
rm -fr dist
build_jax 3.7.2 1.18.5
build_jax 3.8.0 1.18.5
build_jax 3.9.0 1.19.4
build_jax 3.7.2 1.19.5
build_jax 3.8.0 1.19.5
build_jax 3.9.0 1.19.5

View File

@ -33,7 +33,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.7',
install_requires=['scipy', 'numpy>=1.18', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
install_requires=['scipy', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={

View File

@ -39,7 +39,7 @@ setup(
python_requires='>=3.7',
install_requires=[
'absl-py',
'numpy>=1.18',
'numpy>=1.19',
'opt_einsum',
'scipy>=1.2.1',
'typing_extensions',