mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021. The next NumPy deprecation will be 1.19 on Jun 21, 2022. PiperOrigin-RevId: 419651428
This commit is contained in:
parent
eaf7885460
commit
04369a3588
@ -13,6 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.26...main).
|
||||
|
||||
* Breaking changes:
|
||||
* Support for NumPy 1.18 has been dropped, per the
|
||||
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||
Please upgrade to a supported NumPy version.
|
||||
* The host_callback primitives have been simplified to drop the
|
||||
special autodiff handling for hcb.id_tap and id_print.
|
||||
From now on, only the primals are tapped. The old behavior can be
|
||||
@ -36,6 +39,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
(e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not
|
||||
to increase the number of compute capabilities: GPUs with compute capability
|
||||
6.1 can use the 6.0 SASS.
|
||||
* Breaking changes
|
||||
* Support for NumPy 1.18 has been dropped, per the
|
||||
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||
Please upgrade to a supported NumPy version.
|
||||
|
||||
## jaxlib 0.1.75 (Dec 8, 2021)
|
||||
* New features:
|
||||
|
@ -13,9 +13,9 @@ RUN /pyenv/bin/pyenv install 3.8.0
|
||||
RUN /pyenv/bin/pyenv install 3.9.0
|
||||
|
||||
# We pin numpy to the minimum permitted version to avoid compatibility issues.
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.18.5 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.18.5 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.19.5 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.19.5 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.5 setuptools wheel six auditwheel
|
||||
|
||||
# Change the CUDA version if it doesn't match the installed version.
|
||||
ARG JAX_CUDA_VERSION=10.0
|
||||
|
@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path):
|
||||
version = shell(
|
||||
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
|
||||
numpy_version = tuple(map(int, version.split(".")[:2]))
|
||||
if numpy_version < (1, 18):
|
||||
print("ERROR: JAX requires NumPy 1.18 or newer, found " + version + ".")
|
||||
if numpy_version < (1, 19):
|
||||
print("ERROR: JAX requires NumPy 1.19 or newer, found " + version + ".")
|
||||
sys.exit(-1)
|
||||
return version
|
||||
|
||||
|
@ -44,6 +44,6 @@ build_jax () {
|
||||
|
||||
|
||||
rm -fr dist
|
||||
build_jax 3.7.2 1.18.5
|
||||
build_jax 3.8.0 1.18.5
|
||||
build_jax 3.9.0 1.19.4
|
||||
build_jax 3.7.2 1.19.5
|
||||
build_jax 3.8.0 1.19.5
|
||||
build_jax 3.9.0 1.19.5
|
||||
|
@ -33,7 +33,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.7',
|
||||
install_requires=['scipy', 'numpy>=1.18', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
install_requires=['scipy', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
package_data={
|
||||
|
Loading…
x
Reference in New Issue
Block a user