mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Drop support for NumPy 1.16.
This commit is contained in:
parent
87a533e4ea
commit
b130257ee1
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -52,7 +52,7 @@ jobs:
|
|||||||
os: ubuntu-latest
|
os: ubuntu-latest
|
||||||
enable-x64: 1
|
enable-x64: 1
|
||||||
# Test with numpy version that matches Google-internal version
|
# Test with numpy version that matches Google-internal version
|
||||||
package-overrides: "numpy==1.16.4 scipy==1.2.1"
|
package-overrides: "numpy==1.17.5 scipy==1.2.1"
|
||||||
num_generated_cases: 10
|
num_generated_cases: 10
|
||||||
- name-prefix: "with 3.7"
|
- name-prefix: "with 3.7"
|
||||||
python-version: 3.7
|
python-version: 3.7
|
||||||
|
@ -13,6 +13,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* New features:
|
* New features:
|
||||||
|
|
||||||
* Breaking changes:
|
* Breaking changes:
|
||||||
|
* Support for NumPy 1.16 has been dropped, per the
|
||||||
|
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||||
|
|
||||||
* Bug fixes:
|
* Bug fixes:
|
||||||
* Fixed bug that prevented round-tripping from JAX to TF and back:
|
* Fixed bug that prevented round-tripping from JAX to TF and back:
|
||||||
@ -33,8 +35,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
in TF ops. The code that XLA generates after jax2tf
|
in TF ops. The code that XLA generates after jax2tf
|
||||||
has the same location information as JAX/XLA.
|
has the same location information as JAX/XLA.
|
||||||
|
|
||||||
* Breaking changes:
|
|
||||||
|
|
||||||
* Bug fixes:
|
* Bug fixes:
|
||||||
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
|
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
|
||||||
for Python scalars and for choosing 32-bit vs. 64-bit computations
|
for Python scalars and for choosing 32-bit vs. 64-bit computations
|
||||||
|
@ -78,8 +78,8 @@ def check_numpy_version(python_bin_path):
|
|||||||
version = shell(
|
version = shell(
|
||||||
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
|
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
|
||||||
numpy_version = tuple(map(int, version.split('.')[:2]))
|
numpy_version = tuple(map(int, version.split('.')[:2]))
|
||||||
if numpy_version < (1, 16):
|
if numpy_version < (1, 17):
|
||||||
print("ERROR: JAX requires NumPy 1.16 or newer, found " + version + ".")
|
print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
return version
|
return version
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ setup(
|
|||||||
author_email='jax-dev@google.com',
|
author_email='jax-dev@google.com',
|
||||||
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
|
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.6',
|
||||||
install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||||
url='https://github.com/google/jax',
|
url='https://github.com/google/jax',
|
||||||
license='Apache-2.0',
|
license='Apache-2.0',
|
||||||
package_data={
|
package_data={
|
||||||
|
2
setup.py
2
setup.py
@ -34,7 +34,7 @@ setup(
|
|||||||
package_data={'jax': ['py.typed']},
|
package_data={'jax': ['py.typed']},
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.6',
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'numpy >=1.12',
|
'numpy>=1.17',
|
||||||
'absl-py',
|
'absl-py',
|
||||||
'opt_einsum',
|
'opt_einsum',
|
||||||
],
|
],
|
||||||
|
@ -50,7 +50,7 @@ from jax.config import config
|
|||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
FLAGS = config.FLAGS
|
FLAGS = config.FLAGS
|
||||||
|
|
||||||
numpy_version = tuple(map(int, np.version.version.split('.')))
|
numpy_version = tuple(map(int, np.__version__.split('.')))
|
||||||
|
|
||||||
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
|
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
|
||||||
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
|
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
|
||||||
@ -140,6 +140,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
|||||||
jtu.rand_default, [], check_dtypes=False),
|
jtu.rand_default, [], check_dtypes=False),
|
||||||
op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||||
op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||||
|
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
|
||||||
|
check_dtypes=False),
|
||||||
op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
|
op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
|
||||||
op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||||
op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||||
@ -198,13 +200,6 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
|||||||
inexact=True, tolerance={np.float64: 1e-9}),
|
inexact=True, tolerance={np.float64: 1e-9}),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Skip np.i0() tests on older numpy: https://github.com/numpy/numpy/issues/11205
|
|
||||||
if numpy_version >= (1, 17, 0):
|
|
||||||
JAX_ONE_TO_ONE_OP_RECORDS.append(
|
|
||||||
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
|
|
||||||
check_dtypes=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
JAX_COMPOUND_OP_RECORDS = [
|
JAX_COMPOUND_OP_RECORDS = [
|
||||||
# angle has inconsistent 32/64-bit return types across numpy versions.
|
# angle has inconsistent 32/64-bit return types across numpy versions.
|
||||||
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
|
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
|
||||||
@ -821,7 +816,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@unittest.skipIf(numpy_version < (1, 17), "where parameter not supported in older numpy")
|
|
||||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||||
jtu.cases_from_list(
|
jtu.cases_from_list(
|
||||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
|
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
|
||||||
@ -1567,9 +1561,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
tol={np.float32: 1e-3, np.complex64: 1e-3})
|
tol={np.float32: 1e-3, np.complex64: 1e-3})
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@unittest.skipIf(numpy_version < (1, 16, 6),
|
|
||||||
"numpy <= 1.16.5 has a bug in linear_ramp")
|
|
||||||
# https://github.com/numpy/numpy/commit/1c45e0df150b1f49982aaa3fc1a328407b5eff7e
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format(
|
{"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format(
|
||||||
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values),
|
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values),
|
||||||
@ -1618,7 +1609,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
|
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@unittest.skipIf(numpy_version < (1, 17, 0), "empty mode is new in numpy 1.17.0")
|
|
||||||
def testPadEmpty(self):
|
def testPadEmpty(self):
|
||||||
arr = np.arange(6).reshape(2, 3)
|
arr = np.arange(6).reshape(2, 3)
|
||||||
|
|
||||||
@ -1670,7 +1660,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(NotImplementedError, match):
|
with self.assertRaisesRegex(NotImplementedError, match):
|
||||||
jnp.pad(arr, pad_width, mode)
|
jnp.pad(arr, pad_width, mode)
|
||||||
|
|
||||||
@unittest.skipIf(numpy_version < (1, 17, 0), "function mode is new in numpy 1.17.0")
|
|
||||||
def testPadFunction(self):
|
def testPadFunction(self):
|
||||||
def np_pad_with(vector, pad_width, iaxis, kwargs):
|
def np_pad_with(vector, pad_width, iaxis, kwargs):
|
||||||
pad_value = kwargs.get('padder', 10)
|
pad_value = kwargs.get('padder', 10)
|
||||||
@ -2894,7 +2883,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
jnp.ones((-1, 1))
|
jnp.ones((-1, 1))
|
||||||
|
|
||||||
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
|
|
||||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||||
"testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format(
|
"testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format(
|
||||||
jtu.format_shape_dtype_string(shape, in_dtype),
|
jtu.format_shape_dtype_string(shape, in_dtype),
|
||||||
@ -2921,7 +2909,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format(
|
{"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format(
|
||||||
func, jtu.format_shape_dtype_string(shape, in_dtype),
|
func, jtu.format_shape_dtype_string(shape, in_dtype),
|
||||||
@ -2944,7 +2931,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format(
|
{"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format(
|
||||||
func, jtu.format_shape_dtype_string(shape, in_dtype),
|
func, jtu.format_shape_dtype_string(shape, in_dtype),
|
||||||
@ -3949,8 +3935,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
for shape in [(1, 2, 3, 4)]
|
for shape in [(1, 2, 3, 4)]
|
||||||
for axis in [None, 0, 1, -2, -1]))
|
for axis in [None, 0, 1, -2, -1]))
|
||||||
def testPackbits(self, shape, dtype, axis, bitorder):
|
def testPackbits(self, shape, dtype, axis, bitorder):
|
||||||
if numpy_version < (1, 17, 0):
|
|
||||||
raise SkipTest("bitorder arg added in numpy 1.17.0")
|
|
||||||
rng = jtu.rand_some_zero(self.rng())
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
|
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
|
||||||
@ -3969,8 +3953,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
for axis in [None, 0, 1, -2, -1]
|
for axis in [None, 0, 1, -2, -1]
|
||||||
for count in [None, 20]))
|
for count in [None, 20]))
|
||||||
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
|
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
|
||||||
if numpy_version < (1, 17, 0):
|
|
||||||
raise SkipTest("bitorder arg added in numpy 1.17.0")
|
|
||||||
rng = jtu.rand_int(self.rng(), 0, 256)
|
rng = jtu.rand_int(self.rng(), 0, 256)
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
|
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
|
||||||
@ -4144,14 +4126,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
for sparse in [True, False]))
|
for sparse in [True, False]))
|
||||||
def testIndices(self, dimensions, dtype, sparse):
|
def testIndices(self, dimensions, dtype, sparse):
|
||||||
def args_maker(): return []
|
def args_maker(): return []
|
||||||
if numpy_version < (1, 17):
|
np_fun = partial(np.indices, dimensions=dimensions,
|
||||||
if sparse:
|
dtype=dtype, sparse=sparse)
|
||||||
raise SkipTest("indices does not have sparse on numpy < 1.17")
|
|
||||||
np_fun = partial(np.indices, dimensions=dimensions,
|
|
||||||
dtype=dtype)
|
|
||||||
else:
|
|
||||||
np_fun = partial(np.indices, dimensions=dimensions,
|
|
||||||
dtype=dtype, sparse=sparse)
|
|
||||||
jnp_fun = partial(jnp.indices, dimensions=dimensions,
|
jnp_fun = partial(jnp.indices, dimensions=dimensions,
|
||||||
dtype=dtype, sparse=sparse)
|
dtype=dtype, sparse=sparse)
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user