mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Drop support for NumPy 1.16.
This commit is contained in:
parent
87a533e4ea
commit
b130257ee1
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -52,7 +52,7 @@ jobs:
|
||||
os: ubuntu-latest
|
||||
enable-x64: 1
|
||||
# Test with numpy version that matches Google-internal version
|
||||
package-overrides: "numpy==1.16.4 scipy==1.2.1"
|
||||
package-overrides: "numpy==1.17.5 scipy==1.2.1"
|
||||
num_generated_cases: 10
|
||||
- name-prefix: "with 3.7"
|
||||
python-version: 3.7
|
||||
|
@ -13,6 +13,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* New features:
|
||||
|
||||
* Breaking changes:
|
||||
* Support for NumPy 1.16 has been dropped, per the
|
||||
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||
|
||||
* Bug fixes:
|
||||
* Fixed bug that prevented round-tripping from JAX to TF and back:
|
||||
@ -33,8 +35,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
in TF ops. The code that XLA generates after jax2tf
|
||||
has the same location information as JAX/XLA.
|
||||
|
||||
* Breaking changes:
|
||||
|
||||
* Bug fixes:
|
||||
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
|
||||
for Python scalars and for choosing 32-bit vs. 64-bit computations
|
||||
|
@ -78,8 +78,8 @@ def check_numpy_version(python_bin_path):
|
||||
version = shell(
|
||||
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
|
||||
numpy_version = tuple(map(int, version.split('.')[:2]))
|
||||
if numpy_version < (1, 16):
|
||||
print("ERROR: JAX requires NumPy 1.16 or newer, found " + version + ".")
|
||||
if numpy_version < (1, 17):
|
||||
print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".")
|
||||
sys.exit(-1)
|
||||
return version
|
||||
|
||||
|
@ -32,7 +32,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
|
||||
python_requires='>=3.6',
|
||||
install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
package_data={
|
||||
|
2
setup.py
2
setup.py
@ -34,7 +34,7 @@ setup(
|
||||
package_data={'jax': ['py.typed']},
|
||||
python_requires='>=3.6',
|
||||
install_requires=[
|
||||
'numpy >=1.12',
|
||||
'numpy>=1.17',
|
||||
'absl-py',
|
||||
'opt_einsum',
|
||||
],
|
||||
|
@ -50,7 +50,7 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
numpy_version = tuple(map(int, np.version.version.split('.')))
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')))
|
||||
|
||||
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
|
||||
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
|
||||
@ -140,6 +140,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
||||
jtu.rand_default, [], check_dtypes=False),
|
||||
op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||
op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
|
||||
check_dtypes=False),
|
||||
op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
|
||||
op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||
op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||
@ -198,13 +200,6 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
||||
inexact=True, tolerance={np.float64: 1e-9}),
|
||||
]
|
||||
|
||||
# Skip np.i0() tests on older numpy: https://github.com/numpy/numpy/issues/11205
|
||||
if numpy_version >= (1, 17, 0):
|
||||
JAX_ONE_TO_ONE_OP_RECORDS.append(
|
||||
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
|
||||
check_dtypes=False),
|
||||
)
|
||||
|
||||
JAX_COMPOUND_OP_RECORDS = [
|
||||
# angle has inconsistent 32/64-bit return types across numpy versions.
|
||||
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
|
||||
@ -821,7 +816,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 17), "where parameter not supported in older numpy")
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
|
||||
@ -1567,9 +1561,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol={np.float32: 1e-3, np.complex64: 1e-3})
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 16, 6),
|
||||
"numpy <= 1.16.5 has a bug in linear_ramp")
|
||||
# https://github.com/numpy/numpy/commit/1c45e0df150b1f49982aaa3fc1a328407b5eff7e
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values),
|
||||
@ -1618,7 +1609,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 17, 0), "empty mode is new in numpy 1.17.0")
|
||||
def testPadEmpty(self):
|
||||
arr = np.arange(6).reshape(2, 3)
|
||||
|
||||
@ -1670,7 +1660,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(NotImplementedError, match):
|
||||
jnp.pad(arr, pad_width, mode)
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 17, 0), "function mode is new in numpy 1.17.0")
|
||||
def testPadFunction(self):
|
||||
def np_pad_with(vector, pad_width, iaxis, kwargs):
|
||||
pad_value = kwargs.get('padder', 10)
|
||||
@ -2894,7 +2883,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
jnp.ones((-1, 1))
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format(
|
||||
jtu.format_shape_dtype_string(shape, in_dtype),
|
||||
@ -2921,7 +2909,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format(
|
||||
func, jtu.format_shape_dtype_string(shape, in_dtype),
|
||||
@ -2944,7 +2931,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format(
|
||||
func, jtu.format_shape_dtype_string(shape, in_dtype),
|
||||
@ -3949,8 +3935,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for shape in [(1, 2, 3, 4)]
|
||||
for axis in [None, 0, 1, -2, -1]))
|
||||
def testPackbits(self, shape, dtype, axis, bitorder):
|
||||
if numpy_version < (1, 17, 0):
|
||||
raise SkipTest("bitorder arg added in numpy 1.17.0")
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
|
||||
@ -3969,8 +3953,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for axis in [None, 0, 1, -2, -1]
|
||||
for count in [None, 20]))
|
||||
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
|
||||
if numpy_version < (1, 17, 0):
|
||||
raise SkipTest("bitorder arg added in numpy 1.17.0")
|
||||
rng = jtu.rand_int(self.rng(), 0, 256)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
|
||||
@ -4144,14 +4126,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for sparse in [True, False]))
|
||||
def testIndices(self, dimensions, dtype, sparse):
|
||||
def args_maker(): return []
|
||||
if numpy_version < (1, 17):
|
||||
if sparse:
|
||||
raise SkipTest("indices does not have sparse on numpy < 1.17")
|
||||
np_fun = partial(np.indices, dimensions=dimensions,
|
||||
dtype=dtype)
|
||||
else:
|
||||
np_fun = partial(np.indices, dimensions=dimensions,
|
||||
dtype=dtype, sparse=sparse)
|
||||
np_fun = partial(np.indices, dimensions=dimensions,
|
||||
dtype=dtype, sparse=sparse)
|
||||
jnp_fun = partial(jnp.indices, dimensions=dimensions,
|
||||
dtype=dtype, sparse=sparse)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
Loading…
x
Reference in New Issue
Block a user