Increase minimum NumPy version to 1.20.

Per NEP 29, support for 1.19 ended on Jun 21, 2022.
This commit is contained in:
Peter Hawkins 2022-08-06 14:49:09 +00:00
parent c02359b924
commit c735c6bf0e
6 changed files with 14 additions and 49 deletions

View File

@ -10,6 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.16 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
* Breaking changes
* Support for NumPy 1.19 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to NumPy 1.20 or newer.
* Changes
* Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`.
* Added new documentation for [runtime value debugging](debugging/index)

View File

@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path):
version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split(".")[:2]))
if numpy_version < (1, 19):
print("ERROR: JAX requires NumPy 1.19 or newer, found " + version + ".")
if numpy_version < (1, 20):
print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".")
sys.exit(-1)
return version

View File

@ -43,7 +43,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.7',
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py'],
install_requires=['scipy>=1.5', 'numpy>=1.20', 'absl-py'],
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[

View File

@ -65,7 +65,7 @@ setup(
python_requires='>=3.7',
install_requires=[
'absl-py',
'numpy>=1.19',
'numpy>=1.20',
'opt_einsum',
'scipy>=1.5',
'typing_extensions',

View File

@ -30,11 +30,7 @@ from jax._src.numpy.util import _promote_dtypes_complex
from jax.config import config
config.parse_flags_with_absl()
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
if numpy_version < (1, 20):
FFT_NORMS = [None, "ortho"]
else:
FFT_NORMS = [None, "ortho", "forward", "backward"]
FFT_NORMS = [None, "ortho", "forward", "backward"]
float_dtypes = jtu.dtypes.floating

View File

@ -1012,7 +1012,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 20), "where parameter not supported in older numpy")
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
@ -1912,8 +1911,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# following types lack precision
dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])))
def testPadLinearRamp(self, shape, dtype, pad_width, end_values):
if numpy_version < (1, 20) and np.issubdtype(dtype, np.integer):
raise unittest.SkipTest("NumPy 1.20 changed the semantics of np.linspace")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@ -2316,9 +2313,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(*args):
args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32)
for x in args]
if numpy_version < (1, 20):
_dtype = dtype or jnp.result_type(*arg_dtypes)
return np.concatenate(args, axis=axis).astype(_dtype)
return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe')
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype)
@ -2425,16 +2419,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
rng = jtu.rand_default(self.rng())
max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
# Previous to numpy 1.19, negative indices were ignored so we don't test this.
low = 0 if numpy_version < (1, 19, 0) else -max_idx
idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int)
idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int)
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
@ -2606,8 +2597,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for return_inverse in [False, True]
for return_counts in [False, True]))
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0:
self.skipTest("zero-sized axis in unique leads to error in older numpy.")
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
extra_args = (return_index, return_inverse, return_counts)
@ -3462,8 +3451,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for fill_value_shape in s(_compatible_shapes(shape if out_shape is None else out_shape))
for out_dtype in s(default_dtypes))))
def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape):
if numpy_version < (1, 19) and out_shape == ():
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
rng = jtu.rand_default(self.rng())
np_fun = lambda x, fill_value: np.full_like(
x, fill_value, dtype=out_dtype, shape=out_shape)
@ -3485,8 +3472,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for func in ["ones_like", "zeros_like"]
for out_dtype in default_dtypes))
def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
if numpy_version < (1, 19) and out_shape == ():
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
rng = jtu.rand_default(self.rng())
np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape)
jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape)
@ -3509,8 +3494,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())]
for out_dtype in [None, float]))
def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype):
if numpy_version < (1, 19) and out_shape == ():
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
rng = jtu.rand_default(self.rng())
x = lax_internal._convert_element_type(rng(shape, in_dtype),
weak_type=weak_type)
@ -3806,8 +3789,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.resize(x, out_shape)
jnp_fun = lambda x: jnp.resize(x, out_shape)
args_maker = lambda: [rng(arg_shape, dtype)]
if len(out_shape) > 0 or numpy_version >= (1, 20, 0):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -5757,20 +5739,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_op = lambda start, stop: jnp.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
# NumPy 1.20.0 changed the semantics of linspace to floor for integer
# dtypes.
if numpy_version >= (1, 20) or not np.issubdtype(dtype, np.integer):
np_op = lambda start, stop: np.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
else:
def np_op(start, stop):
out = np.linspace(start, stop, num, endpoint=endpoint,
retstep=retstep, axis=axis)
if retstep:
return np.floor(out[0]).astype(dtype), out[1]
else:
return np.floor(out).astype(dtype)
np_op = lambda start, stop: np.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
check_dtypes=False, tol=tol)
@ -6448,12 +6419,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
mismatches = {}
for name, (jnp_fun, np_fun) in func_pairs.items():
# broadcast_shapes is not available in numpy < 1.20
if numpy_version < (1, 20) and name == "broadcast_shapes":
continue
# Some signatures have changed; skip for older numpy versions.
if numpy_version < (1, 19) and name in ['einsum_path', 'gradient', 'isscalar']:
continue
if numpy_version < (1, 22) and name in ['quantile', 'nanquantile',
'percentile', 'nanpercentile']:
continue