Drop support for NumPy 1.16.

This commit is contained in:
Peter Hawkins 2021-06-10 12:12:13 -04:00
parent 87a533e4ea
commit b130257ee1
6 changed files with 12 additions and 36 deletions

View File

@ -52,7 +52,7 @@ jobs:
os: ubuntu-latest
enable-x64: 1
# Test with numpy version that matches Google-internal version
package-overrides: "numpy==1.16.4 scipy==1.2.1"
package-overrides: "numpy==1.17.5 scipy==1.2.1"
num_generated_cases: 10
- name-prefix: "with 3.7"
python-version: 3.7

View File

@ -13,6 +13,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:
* Breaking changes:
* Support for NumPy 1.16 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
* Bug fixes:
* Fixed bug that prevented round-tripping from JAX to TF and back:
@ -33,8 +35,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
in TF ops. The code that XLA generates after jax2tf
has the same location information as JAX/XLA.
* Breaking changes:
* Bug fixes:
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
for Python scalars and for choosing 32-bit vs. 64-bit computations

View File

@ -78,8 +78,8 @@ def check_numpy_version(python_bin_path):
version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split('.')[:2]))
if numpy_version < (1, 16):
print("ERROR: JAX requires NumPy 1.16 or newer, found " + version + ".")
if numpy_version < (1, 17):
print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".")
sys.exit(-1)
return version

View File

@ -32,7 +32,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
python_requires='>=3.6',
install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={

View File

@ -34,7 +34,7 @@ setup(
package_data={'jax': ['py.typed']},
python_requires='>=3.6',
install_requires=[
'numpy >=1.12',
'numpy>=1.17',
'absl-py',
'opt_einsum',
],

View File

@ -50,7 +50,7 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
numpy_version = tuple(map(int, np.version.version.split('.')))
numpy_version = tuple(map(int, np.__version__.split('.')))
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
@ -140,6 +140,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
jtu.rand_default, [], check_dtypes=False),
op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
@ -198,13 +200,6 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
inexact=True, tolerance={np.float64: 1e-9}),
]
# Skip np.i0() tests on older numpy: https://github.com/numpy/numpy/issues/11205
if numpy_version >= (1, 17, 0):
JAX_ONE_TO_ONE_OP_RECORDS.append(
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
)
JAX_COMPOUND_OP_RECORDS = [
# angle has inconsistent 32/64-bit return types across numpy versions.
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
@ -821,7 +816,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17), "where parameter not supported in older numpy")
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
@ -1567,9 +1561,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
tol={np.float32: 1e-3, np.complex64: 1e-3})
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 16, 6),
"numpy <= 1.16.5 has a bug in linear_ramp")
# https://github.com/numpy/numpy/commit/1c45e0df150b1f49982aaa3fc1a328407b5eff7e
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format(
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values),
@ -1618,7 +1609,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17, 0), "empty mode is new in numpy 1.17.0")
def testPadEmpty(self):
arr = np.arange(6).reshape(2, 3)
@ -1670,7 +1660,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(NotImplementedError, match):
jnp.pad(arr, pad_width, mode)
@unittest.skipIf(numpy_version < (1, 17, 0), "function mode is new in numpy 1.17.0")
def testPadFunction(self):
def np_pad_with(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10)
@ -2894,7 +2883,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaises(TypeError):
jnp.ones((-1, 1))
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format(
jtu.format_shape_dtype_string(shape, in_dtype),
@ -2921,7 +2909,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format(
func, jtu.format_shape_dtype_string(shape, in_dtype),
@ -2944,7 +2931,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format(
func, jtu.format_shape_dtype_string(shape, in_dtype),
@ -3949,8 +3935,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for shape in [(1, 2, 3, 4)]
for axis in [None, 0, 1, -2, -1]))
def testPackbits(self, shape, dtype, axis, bitorder):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
@ -3969,8 +3953,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for axis in [None, 0, 1, -2, -1]
for count in [None, 20]))
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = jtu.rand_int(self.rng(), 0, 256)
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
@ -4144,14 +4126,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for sparse in [True, False]))
def testIndices(self, dimensions, dtype, sparse):
def args_maker(): return []
if numpy_version < (1, 17):
if sparse:
raise SkipTest("indices does not have sparse on numpy < 1.17")
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype)
else:
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
jnp_fun = partial(jnp.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)