diff --git a/CHANGELOG.md b/CHANGELOG.md index 16b1e6be4..f5c3929ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Remember to align the itemized text with the first line of an item within a list like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects. + * JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer. * Deprecations * The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead, diff --git a/build/build.py b/build/build.py index aaa71e56d..0ac3a58e5 100755 --- a/build/build.py +++ b/build/build.py @@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path): version = shell( [python_bin_path, "-c", "import numpy as np; print(np.__version__)"]) numpy_version = tuple(map(int, version.split(".")[:2])) - if numpy_version < (1, 20): - print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".") + if numpy_version < (1, 21): + print("ERROR: JAX requires NumPy 1.21 or newer, found " + version + ".") sys.exit(-1) return version diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 1c607dfd3..14a0d83eb 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -46,7 +46,7 @@ setup( author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], python_requires='>=3.8', - install_requires=['scipy>=1.5', 'numpy>=1.20', 'ml_dtypes>=0.0.3'], + install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.0.3'], url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ diff --git a/setup.py b/setup.py index 2c8b186e5..466603eda 100644 --- a/setup.py +++ b/setup.py @@ -65,9 +65,9 @@ setup( python_requires='>=3.8', install_requires=[ 'ml_dtypes>=0.0.3', - 'numpy>=1.20', + 'numpy>=1.21', 'opt_einsum', - 'scipy>=1.5', + 'scipy>=1.7', ], extras_require={ # Minimum jaxlib version; used in testing. diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index ee38560d1..bb0d758d8 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -492,8 +492,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] - if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) def testReductionOfOutOfBoundsAxis(self): # Issue 888 diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index fb7a5af9a..8ef546d0a 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1322,8 +1322,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def testMode(self, shape, dtype, axis, contains_nans, keepdims): if scipy_version < (1, 9, 0) and keepdims != True: self.skipTest("scipy < 1.9.0 only support keepdims == True") - if numpy_version < (1, 21, 0) and contains_nans: - self.skipTest("numpy < 1.21.0 only support contains_nans == False") if contains_nans: rng = jtu.rand_some_nan(self.rng())