mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
This commit is contained in:
parent
c02359b924
commit
c735c6bf0e
@ -10,6 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jax 0.3.16 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
|
||||
* Breaking changes
|
||||
* Support for NumPy 1.19 has been dropped, per the
|
||||
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||
Please upgrade to NumPy 1.20 or newer.
|
||||
* Changes
|
||||
* Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`.
|
||||
* Added new documentation for [runtime value debugging](debugging/index)
|
||||
|
@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path):
|
||||
version = shell(
|
||||
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
|
||||
numpy_version = tuple(map(int, version.split(".")[:2]))
|
||||
if numpy_version < (1, 19):
|
||||
print("ERROR: JAX requires NumPy 1.19 or newer, found " + version + ".")
|
||||
if numpy_version < (1, 20):
|
||||
print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".")
|
||||
sys.exit(-1)
|
||||
return version
|
||||
|
||||
|
@ -43,7 +43,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.7',
|
||||
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py'],
|
||||
install_requires=['scipy>=1.5', 'numpy>=1.20', 'absl-py'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
classifiers=[
|
||||
|
2
setup.py
2
setup.py
@ -65,7 +65,7 @@ setup(
|
||||
python_requires='>=3.7',
|
||||
install_requires=[
|
||||
'absl-py',
|
||||
'numpy>=1.19',
|
||||
'numpy>=1.20',
|
||||
'opt_einsum',
|
||||
'scipy>=1.5',
|
||||
'typing_extensions',
|
||||
|
@ -30,11 +30,7 @@ from jax._src.numpy.util import _promote_dtypes_complex
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
if numpy_version < (1, 20):
|
||||
FFT_NORMS = [None, "ortho"]
|
||||
else:
|
||||
FFT_NORMS = [None, "ortho", "forward", "backward"]
|
||||
FFT_NORMS = [None, "ortho", "forward", "backward"]
|
||||
|
||||
|
||||
float_dtypes = jtu.dtypes.floating
|
||||
|
@ -1012,7 +1012,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 20), "where parameter not supported in older numpy")
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
|
||||
@ -1912,8 +1911,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# following types lack precision
|
||||
dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])))
|
||||
def testPadLinearRamp(self, shape, dtype, pad_width, end_values):
|
||||
if numpy_version < (1, 20) and np.issubdtype(dtype, np.integer):
|
||||
raise unittest.SkipTest("NumPy 1.20 changed the semantics of np.linspace")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
@ -2316,9 +2313,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(*args):
|
||||
args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32)
|
||||
for x in args]
|
||||
if numpy_version < (1, 20):
|
||||
_dtype = dtype or jnp.result_type(*arg_dtypes)
|
||||
return np.concatenate(args, axis=axis).astype(_dtype)
|
||||
return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe')
|
||||
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype)
|
||||
|
||||
@ -2425,16 +2419,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
|
||||
# Previous to numpy 1.19, negative indices were ignored so we don't test this.
|
||||
low = 0 if numpy_version < (1, 19, 0) else -max_idx
|
||||
idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int)
|
||||
idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
|
||||
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
@ -2606,8 +2597,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for return_inverse in [False, True]
|
||||
for return_counts in [False, True]))
|
||||
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
|
||||
if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0:
|
||||
self.skipTest("zero-sized axis in unique leads to error in older numpy.")
|
||||
rng = jtu.rand_some_equal(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
extra_args = (return_index, return_inverse, return_counts)
|
||||
@ -3462,8 +3451,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for fill_value_shape in s(_compatible_shapes(shape if out_shape is None else out_shape))
|
||||
for out_dtype in s(default_dtypes))))
|
||||
def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape):
|
||||
if numpy_version < (1, 19) and out_shape == ():
|
||||
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np_fun = lambda x, fill_value: np.full_like(
|
||||
x, fill_value, dtype=out_dtype, shape=out_shape)
|
||||
@ -3485,8 +3472,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for func in ["ones_like", "zeros_like"]
|
||||
for out_dtype in default_dtypes))
|
||||
def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
|
||||
if numpy_version < (1, 19) and out_shape == ():
|
||||
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape)
|
||||
jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape)
|
||||
@ -3509,8 +3494,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())]
|
||||
for out_dtype in [None, float]))
|
||||
def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype):
|
||||
if numpy_version < (1, 19) and out_shape == ():
|
||||
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = lax_internal._convert_element_type(rng(shape, in_dtype),
|
||||
weak_type=weak_type)
|
||||
@ -3806,8 +3789,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.resize(x, out_shape)
|
||||
jnp_fun = lambda x: jnp.resize(x, out_shape)
|
||||
args_maker = lambda: [rng(arg_shape, dtype)]
|
||||
if len(out_shape) > 0 or numpy_version >= (1, 20, 0):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -5757,20 +5739,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_op = lambda start, stop: jnp.linspace(
|
||||
start, stop, num,
|
||||
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
|
||||
# NumPy 1.20.0 changed the semantics of linspace to floor for integer
|
||||
# dtypes.
|
||||
if numpy_version >= (1, 20) or not np.issubdtype(dtype, np.integer):
|
||||
np_op = lambda start, stop: np.linspace(
|
||||
start, stop, num,
|
||||
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
|
||||
else:
|
||||
def np_op(start, stop):
|
||||
out = np.linspace(start, stop, num, endpoint=endpoint,
|
||||
retstep=retstep, axis=axis)
|
||||
if retstep:
|
||||
return np.floor(out[0]).astype(dtype), out[1]
|
||||
else:
|
||||
return np.floor(out).astype(dtype)
|
||||
np_op = lambda start, stop: np.linspace(
|
||||
start, stop, num,
|
||||
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
|
||||
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
|
||||
check_dtypes=False, tol=tol)
|
||||
@ -6448,12 +6419,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
mismatches = {}
|
||||
|
||||
for name, (jnp_fun, np_fun) in func_pairs.items():
|
||||
# broadcast_shapes is not available in numpy < 1.20
|
||||
if numpy_version < (1, 20) and name == "broadcast_shapes":
|
||||
continue
|
||||
# Some signatures have changed; skip for older numpy versions.
|
||||
if numpy_version < (1, 19) and name in ['einsum_path', 'gradient', 'isscalar']:
|
||||
continue
|
||||
if numpy_version < (1, 22) and name in ['quantile', 'nanquantile',
|
||||
'percentile', 'nanpercentile']:
|
||||
continue
|
||||
|
Loading…
x
Reference in New Issue
Block a user