1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

Drop support for NumPy 1.17.

This commit is contained in:
Peter Hawkins 2021-07-29 09:18:01 -04:00
parent c4602475ca
commit 6e9169d100
6 changed files with 13 additions and 10 deletions

@ -62,7 +62,7 @@ jobs:
os: ubuntu-latest
enable-x64: 0
# Test with the oldest legal NumPy version.
package-overrides: "numpy==1.17.5 scipy==1.2.1"
package-overrides: "numpy==1.18.5 scipy==1.2.1"
num_generated_cases: 8
# Test against latest jaxlib
use-latest-jaxlib: true

@ -10,7 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.19 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.18...main).
* Breaking changes:
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* New features:
* Improved the support for shape polymorphism in jax2tf for operations that
need to use a dimension size in array computation, e.g., `jnp.mean`.
@ -21,7 +24,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The host_callback mechanism now uses one thread per local device for
making the calls to the Python callbacks. Previously there was a single
thread for all devices. This means that the callbacks may now be called

@ -82,8 +82,8 @@ def check_numpy_version(python_bin_path):
version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split('.')[:2]))
if numpy_version < (1, 17):
print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".")
if numpy_version < (1, 18):
print("ERROR: JAX requires NumPy 1.18 or newer, found " + version + ".")
sys.exit(-1)
return version

@ -32,7 +32,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
python_requires='>=3.7',
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
install_requires=['scipy', 'numpy>=1.18', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={

@ -36,7 +36,7 @@ setup(
package_data={'jax': ['py.typed']},
python_requires='>=3.7',
install_requires=[
'numpy>=1.17',
'numpy>=1.18',
'absl-py',
'opt_einsum',
],

@ -1406,7 +1406,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
args_maker = lambda: [rng(shape, dtype)]
# TODO(phawkins): the promotion behavior changed in Numpy 1.17.
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@ -4226,8 +4225,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
else:
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
# TODO(jakevdp): remove this ignore_warning when minimum numpy version is 1.17.0
@jtu.ignore_warning(category=RuntimeWarning, message="Invalid value encountered.*")
def np_fun(*args):
args = [x if jnp.result_type(x) != jnp.bfloat16 else
np.asarray(x, np.float32) for x in args]