Drop support for NumPy 1.16.

This commit is contained in:
Peter Hawkins 2021-06-10 12:12:13 -04:00
parent 87a533e4ea
commit b130257ee1
6 changed files with 12 additions and 36 deletions

View File

@ -52,7 +52,7 @@ jobs:
os: ubuntu-latest os: ubuntu-latest
enable-x64: 1 enable-x64: 1
# Test with numpy version that matches Google-internal version # Test with numpy version that matches Google-internal version
package-overrides: "numpy==1.16.4 scipy==1.2.1" package-overrides: "numpy==1.17.5 scipy==1.2.1"
num_generated_cases: 10 num_generated_cases: 10
- name-prefix: "with 3.7" - name-prefix: "with 3.7"
python-version: 3.7 python-version: 3.7

View File

@ -13,6 +13,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* New features: * New features:
* Breaking changes: * Breaking changes:
* Support for NumPy 1.16 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
* Bug fixes: * Bug fixes:
* Fixed bug that prevented round-tripping from JAX to TF and back: * Fixed bug that prevented round-tripping from JAX to TF and back:
@ -33,8 +35,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
in TF ops. The code that XLA generates after jax2tf in TF ops. The code that XLA generates after jax2tf
has the same location information as JAX/XLA. has the same location information as JAX/XLA.
* Breaking changes:
* Bug fixes: * Bug fixes:
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules * The {func}`jax2tf.convert` now ensures that it uses the same typing rules
for Python scalars and for choosing 32-bit vs. 64-bit computations for Python scalars and for choosing 32-bit vs. 64-bit computations

View File

@ -78,8 +78,8 @@ def check_numpy_version(python_bin_path):
version = shell( version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"]) [python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split('.')[:2])) numpy_version = tuple(map(int, version.split('.')[:2]))
if numpy_version < (1, 16): if numpy_version < (1, 17):
print("ERROR: JAX requires NumPy 1.16 or newer, found " + version + ".") print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".")
sys.exit(-1) sys.exit(-1)
return version return version

View File

@ -32,7 +32,7 @@ setup(
author_email='jax-dev@google.com', author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension-stubs'], packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
python_requires='>=3.6', python_requires='>=3.6',
install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers >= 1.12, < 3.0'], install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
url='https://github.com/google/jax', url='https://github.com/google/jax',
license='Apache-2.0', license='Apache-2.0',
package_data={ package_data={

View File

@ -34,7 +34,7 @@ setup(
package_data={'jax': ['py.typed']}, package_data={'jax': ['py.typed']},
python_requires='>=3.6', python_requires='>=3.6',
install_requires=[ install_requires=[
'numpy >=1.12', 'numpy>=1.17',
'absl-py', 'absl-py',
'opt_einsum', 'opt_einsum',
], ],

View File

@ -50,7 +50,7 @@ from jax.config import config
config.parse_flags_with_absl() config.parse_flags_with_absl()
FLAGS = config.FLAGS FLAGS = config.FLAGS
numpy_version = tuple(map(int, np.version.version.split('.'))) numpy_version = tuple(map(int, np.__version__.split('.')))
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
@ -140,6 +140,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
jtu.rand_default, [], check_dtypes=False), jtu.rand_default, [], check_dtypes=False),
op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
@ -198,13 +200,6 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
inexact=True, tolerance={np.float64: 1e-9}), inexact=True, tolerance={np.float64: 1e-9}),
] ]
# Skip np.i0() tests on older numpy: https://github.com/numpy/numpy/issues/11205
if numpy_version >= (1, 17, 0):
JAX_ONE_TO_ONE_OP_RECORDS.append(
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
)
JAX_COMPOUND_OP_RECORDS = [ JAX_COMPOUND_OP_RECORDS = [
# angle has inconsistent 32/64-bit return types across numpy versions. # angle has inconsistent 32/64-bit return types across numpy versions.
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
@ -821,7 +816,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17), "where parameter not supported in older numpy")
@parameterized.named_parameters(itertools.chain.from_iterable( @parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list( jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format( {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
@ -1567,9 +1561,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
tol={np.float32: 1e-3, np.complex64: 1e-3}) tol={np.float32: 1e-3, np.complex64: 1e-3})
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 16, 6),
"numpy <= 1.16.5 has a bug in linear_ramp")
# https://github.com/numpy/numpy/commit/1c45e0df150b1f49982aaa3fc1a328407b5eff7e
@parameterized.named_parameters(jtu.cases_from_list( @parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format( {"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format(
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values), jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values),
@ -1618,7 +1609,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17, 0), "empty mode is new in numpy 1.17.0")
def testPadEmpty(self): def testPadEmpty(self):
arr = np.arange(6).reshape(2, 3) arr = np.arange(6).reshape(2, 3)
@ -1670,7 +1660,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(NotImplementedError, match): with self.assertRaisesRegex(NotImplementedError, match):
jnp.pad(arr, pad_width, mode) jnp.pad(arr, pad_width, mode)
@unittest.skipIf(numpy_version < (1, 17, 0), "function mode is new in numpy 1.17.0")
def testPadFunction(self): def testPadFunction(self):
def np_pad_with(vector, pad_width, iaxis, kwargs): def np_pad_with(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10) pad_value = kwargs.get('padder', 10)
@ -2894,7 +2883,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
jnp.ones((-1, 1)) jnp.ones((-1, 1))
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format( "testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format(
jtu.format_shape_dtype_string(shape, in_dtype), jtu.format_shape_dtype_string(shape, in_dtype),
@ -2921,7 +2909,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.cases_from_list( @parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format( {"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format(
func, jtu.format_shape_dtype_string(shape, in_dtype), func, jtu.format_shape_dtype_string(shape, in_dtype),
@ -2944,7 +2931,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.cases_from_list( @parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format( {"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format(
func, jtu.format_shape_dtype_string(shape, in_dtype), func, jtu.format_shape_dtype_string(shape, in_dtype),
@ -3949,8 +3935,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for shape in [(1, 2, 3, 4)] for shape in [(1, 2, 3, 4)]
for axis in [None, 0, 1, -2, -1])) for axis in [None, 0, 1, -2, -1]))
def testPackbits(self, shape, dtype, axis, bitorder): def testPackbits(self, shape, dtype, axis, bitorder):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = jtu.rand_some_zero(self.rng()) rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder) jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
@ -3969,8 +3953,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for axis in [None, 0, 1, -2, -1] for axis in [None, 0, 1, -2, -1]
for count in [None, 20])) for count in [None, 20]))
def testUnpackbits(self, shape, dtype, axis, bitorder, count): def testUnpackbits(self, shape, dtype, axis, bitorder, count):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = jtu.rand_int(self.rng(), 0, 256) rng = jtu.rand_int(self.rng(), 0, 256)
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder) jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
@ -4144,14 +4126,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for sparse in [True, False])) for sparse in [True, False]))
def testIndices(self, dimensions, dtype, sparse): def testIndices(self, dimensions, dtype, sparse):
def args_maker(): return [] def args_maker(): return []
if numpy_version < (1, 17): np_fun = partial(np.indices, dimensions=dimensions,
if sparse: dtype=dtype, sparse=sparse)
raise SkipTest("indices does not have sparse on numpy < 1.17")
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype)
else:
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
jnp_fun = partial(jnp.indices, dimensions=dimensions, jnp_fun = partial(jnp.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse) dtype=dtype, sparse=sparse)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)