diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 0ebde27c2..1831f5358 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -52,7 +52,7 @@ jobs: os: ubuntu-latest enable-x64: 1 # Test with numpy version that matches Google-internal version - package-overrides: "numpy==1.16.4 scipy==1.2.1" + package-overrides: "numpy==1.17.5 scipy==1.2.1" num_generated_cases: 10 - name-prefix: "with 3.7" python-version: 3.7 diff --git a/CHANGELOG.md b/CHANGELOG.md index e6782f692..758553939 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. * New features: * Breaking changes: + * Support for NumPy 1.16 has been dropped, per the + [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). * Bug fixes: * Fixed bug that prevented round-tripping from JAX to TF and back: @@ -33,8 +35,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA. -* Breaking changes: - * Bug fixes: * The {func}`jax2tf.convert` now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations diff --git a/build/build.py b/build/build.py index 1e7d0b4b8..c7cb95a13 100755 --- a/build/build.py +++ b/build/build.py @@ -78,8 +78,8 @@ def check_numpy_version(python_bin_path): version = shell( [python_bin_path, "-c", "import numpy as np; print(np.__version__)"]) numpy_version = tuple(map(int, version.split('.')[:2])) - if numpy_version < (1, 16): - print("ERROR: JAX requires NumPy 1.16 or newer, found " + version + ".") + if numpy_version < (1, 17): + print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".") sys.exit(-1) return version diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 2428ac6aa..d88317cb3 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -32,7 +32,7 @@ setup( author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension-stubs'], python_requires='>=3.6', - install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers >= 1.12, < 3.0'], + install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'], url='https://github.com/google/jax', license='Apache-2.0', package_data={ diff --git a/setup.py b/setup.py index 889b39de7..112374159 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ setup( package_data={'jax': ['py.typed']}, python_requires='>=3.6', install_requires=[ - 'numpy >=1.12', + 'numpy>=1.17', 'absl-py', 'opt_einsum', ], diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 1d3e475b2..b8f9b1d2d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -50,7 +50,7 @@ from jax.config import config config.parse_flags_with_absl() FLAGS = config.FLAGS -numpy_version = tuple(map(int, np.version.version.split('.'))) +numpy_version = tuple(map(int, np.__version__.split('.'))) nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes @@ -140,6 +140,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [ jtu.rand_default, [], check_dtypes=False), op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [], + check_dtypes=False), op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), @@ -198,13 +200,6 @@ JAX_ONE_TO_ONE_OP_RECORDS = [ inexact=True, tolerance={np.float64: 1e-9}), ] -# Skip np.i0() tests on older numpy: https://github.com/numpy/numpy/issues/11205 -if numpy_version >= (1, 17, 0): - JAX_ONE_TO_ONE_OP_RECORDS.append( - op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [], - check_dtypes=False), - ) - JAX_COMPOUND_OP_RECORDS = [ # angle has inconsistent 32/64-bit return types across numpy versions. op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], @@ -821,7 +816,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - @unittest.skipIf(numpy_version < (1, 17), "where parameter not supported in older numpy") @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format( @@ -1567,9 +1561,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol={np.float32: 1e-3, np.complex64: 1e-3}) self._CompileAndCheck(jnp_fun, args_maker) - @unittest.skipIf(numpy_version < (1, 16, 6), - "numpy <= 1.16.5 has a bug in linear_ramp") - # https://github.com/numpy/numpy/commit/1c45e0df150b1f49982aaa3fc1a328407b5eff7e @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format( jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values), @@ -1618,7 +1609,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) self._CompileAndCheck(jnp_fun, args_maker) - @unittest.skipIf(numpy_version < (1, 17, 0), "empty mode is new in numpy 1.17.0") def testPadEmpty(self): arr = np.arange(6).reshape(2, 3) @@ -1670,7 +1660,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): with self.assertRaisesRegex(NotImplementedError, match): jnp.pad(arr, pad_width, mode) - @unittest.skipIf(numpy_version < (1, 17, 0), "function mode is new in numpy 1.17.0") def testPadFunction(self): def np_pad_with(vector, pad_width, iaxis, kwargs): pad_value = kwargs.get('padder', 10) @@ -2894,7 +2883,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): with self.assertRaises(TypeError): jnp.ones((-1, 1)) - @unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy") @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format( jtu.format_shape_dtype_string(shape, in_dtype), @@ -2921,7 +2909,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - @unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy") @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format( func, jtu.format_shape_dtype_string(shape, in_dtype), @@ -2944,7 +2931,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) - @unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy") @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format( func, jtu.format_shape_dtype_string(shape, in_dtype), @@ -3949,8 +3935,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): for shape in [(1, 2, 3, 4)] for axis in [None, 0, 1, -2, -1])) def testPackbits(self, shape, dtype, axis, bitorder): - if numpy_version < (1, 17, 0): - raise SkipTest("bitorder arg added in numpy 1.17.0") rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder) @@ -3969,8 +3953,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): for axis in [None, 0, 1, -2, -1] for count in [None, 20])) def testUnpackbits(self, shape, dtype, axis, bitorder, count): - if numpy_version < (1, 17, 0): - raise SkipTest("bitorder arg added in numpy 1.17.0") rng = jtu.rand_int(self.rng(), 0, 256) args_maker = lambda: [rng(shape, dtype)] jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder) @@ -4144,14 +4126,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): for sparse in [True, False])) def testIndices(self, dimensions, dtype, sparse): def args_maker(): return [] - if numpy_version < (1, 17): - if sparse: - raise SkipTest("indices does not have sparse on numpy < 1.17") - np_fun = partial(np.indices, dimensions=dimensions, - dtype=dtype) - else: - np_fun = partial(np.indices, dimensions=dimensions, - dtype=dtype, sparse=sparse) + np_fun = partial(np.indices, dimensions=dimensions, + dtype=dtype, sparse=sparse) jnp_fun = partial(jnp.indices, dimensions=dimensions, dtype=dtype, sparse=sparse) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)