Increase minimum NumPy version to 1.21.

Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
This commit is contained in:
Peter Hawkins 2023-02-06 11:32:28 -05:00
parent e9bc7ee866
commit b7375b316b
6 changed files with 7 additions and 9 deletions

View File

@ -24,6 +24,7 @@ Remember to align the itemized text with the first line of an item within a list
like bfloat16. These definitions were previously internal to JAX, but have
been split into a separate package to facilitate sharing them with other
projects.
* JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
* Deprecations
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,

View File

@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path):
version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split(".")[:2]))
if numpy_version < (1, 20):
print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".")
if numpy_version < (1, 21):
print("ERROR: JAX requires NumPy 1.21 or newer, found " + version + ".")
sys.exit(-1)
return version

View File

@ -46,7 +46,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.8',
install_requires=['scipy>=1.5', 'numpy>=1.20', 'ml_dtypes>=0.0.3'],
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.0.3'],
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[

View File

@ -65,9 +65,9 @@ setup(
python_requires='>=3.8',
install_requires=[
'ml_dtypes>=0.0.3',
'numpy>=1.20',
'numpy>=1.21',
'opt_einsum',
'scipy>=1.5',
'scipy>=1.7',
],
extras_require={
# Minimum jaxlib version; used in testing.

View File

@ -492,8 +492,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
def testReductionOfOutOfBoundsAxis(self): # Issue 888

View File

@ -1322,8 +1322,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def testMode(self, shape, dtype, axis, contains_nans, keepdims):
if scipy_version < (1, 9, 0) and keepdims != True:
self.skipTest("scipy < 1.9.0 only support keepdims == True")
if numpy_version < (1, 21, 0) and contains_nans:
self.skipTest("numpy < 1.21.0 only support contains_nans == False")
if contains_nans:
rng = jtu.rand_some_nan(self.rng())