From 73de02d5ce77d453a6db53853292bdd94ec1547a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 8 Dec 2022 19:40:56 +0000 Subject: [PATCH] Make JAX tests pass under NumPy 1.24.0rc2. * allow rc2 in numpy versions when parsed by tests. * don't cast np.empty(), which can lead to cast errors. * NumPy 1.24 now warns on overflowing scalar int to array casts in more places. --- jax/_src/lax/lax.py | 2 +- jax/_src/test_util.py | 10 ++++++++++ tests/api_test.py | 3 +-- tests/array_interoperability_test.py | 2 +- tests/dynamic_api_test.py | 1 - tests/lax_numpy_reducers_test.py | 2 +- tests/lax_numpy_test.py | 18 +++++++++++------- tests/lax_test.py | 2 +- tests/pmap_test.py | 2 +- tests/random_test.py | 2 +- tests/scipy_stats_test.py | 2 +- 11 files changed, 29 insertions(+), 17 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b4a4614a4..cc05a7f9f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1060,7 +1060,7 @@ def _get_monoid_reducer(monoid_op: Callable, return None def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray: - return np.array(-1, dtype) + return np.array(-1).astype(dtype) def _get_bitwise_or_identity(dtype: DTypeLike) -> np.ndarray: return np.array(0, dtype) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index cb2ccef80..ff611c1dd 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1127,3 +1127,13 @@ def strict_promotion_if_dtypes_match(dtypes): if all(dtype == dtypes[0] for dtype in dtypes): return jax.numpy_dtype_promotion('strict') return jax.numpy_dtype_promotion('standard') + +_version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?") +def _parse_version(v: str) -> Tuple[int, ...]: + m = _version_regex.match(v) + if m is None: + raise ValueError(f"Unable to parse version '{v}'") + return tuple(int(x) for x in m.group(1).split('.')) + +def numpy_version(): + return _parse_version(np.__version__) diff --git a/tests/api_test.py b/tests/api_test.py index facc17e13..2fb7a353b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -78,8 +78,7 @@ FLAGS = config.FLAGS python_version = (sys.version_info[0], sys.version_info[1]) -numpy_version = tuple(map(int, np.__version__.split('.')[:3])) - +numpy_version = jtu.numpy_version() def _check_instance(self, x): if config.jax_array: diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 71d821dd0..0ec896336 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -25,7 +25,7 @@ from jax._src import test_util as jtu import numpy as np -numpy_version = tuple(map(int, np.__version__.split('.')[:3])) +numpy_version = jtu.numpy_version() config.parse_flags_with_absl() diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 151cb5f0c..1a677a2d3 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -35,7 +35,6 @@ FLAGS = config.FLAGS python_version = (sys.version_info[0], sys.version_info[1]) -numpy_version = tuple(map(int, np.__version__.split('.')[:3])) # TODO(https://github.com/google/jax/issues/12291): Enable jax.Array diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 4e100e258..e89324ce1 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -32,7 +32,7 @@ from jax.config import config config.parse_flags_with_absl() FLAGS = config.FLAGS -numpy_version = tuple(map(int, np.__version__.split('.')[:3])) +numpy_version = jtu.numpy_version() nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 60f382f0a..bb194f54b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -52,7 +52,7 @@ from jax.config import config config.parse_flags_with_absl() FLAGS = config.FLAGS -numpy_version = tuple(map(int, np.__version__.split('.')[:3])) +numpy_version = jtu.numpy_version() nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes @@ -2315,7 +2315,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertEqual(out_int64.dtype, np.int64) else: with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + with jtu.ignore_warning(category=DeprecationWarning, + message="NumPy will stop allowing conversion.*"): + out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) @jtu.sample_product( dtype=inexact_dtypes, @@ -2432,7 +2434,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if numpy_version < (1, 24) or op == "dstack": np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) else: - np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype) + np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, + casting='unsafe') jnp_fun = partial(getattr(jnp, op), dtype=out_dtype) with jtu.strict_promotion_if_dtypes_match(dtypes): @@ -3051,19 +3054,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64')) def testArrayFromList(self): - int_max = jnp.iinfo(jnp.int64).max - int_min = jnp.iinfo(jnp.int64).min + dtype = dtypes.canonicalize_dtype('int64') + int_max = jnp.iinfo(dtype).max + int_min = jnp.iinfo(dtype).min # Values at extremes are converted correctly. for val in [int_min, 0, int_max]: - self.assertEqual(jnp.array([val]).dtype, dtypes.canonicalize_dtype('int64')) + self.assertEqual(jnp.array([val]).dtype, dtype) # list of values results in promoted type. with jax.numpy_dtype_promotion('standard'): self.assertEqual(jnp.array([0, np.float16(1)]).dtype, jnp.result_type('int64', 'float16')) # out of bounds leads to an OverflowError - val = int_min - 1 + val = jnp.iinfo(jnp.int64).min - 1 with self.assertRaisesRegex(OverflowError, "Python int too large.*"): jnp.array([0, val]) diff --git a/tests/lax_test.py b/tests/lax_test.py index 6622d0feb..9f49cab15 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1650,7 +1650,7 @@ class LaxTest(jtu.JaxTestCase): rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer) else jtu.rand_small) rng = rng_factory(self.rng()) - init_val = np.asarray(init_val, dtype=dtype) + init_val = np.asarray(init_val).astype(dtype) fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims) args_maker = lambda: [rng(shape, dtype), init_val] self._CompileAndCheck(fun, args_maker) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 9c8e41d9f..7da08530b 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2294,7 +2294,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase): instance_shape.insert(concat_axis, pmap_dim_id) expected_shape = (split_axis_id, vmap_dim_id, *instance_shape) - x = np.empty(start_shape) + x = np.ones(start_shape) self.assertEqual(reference(x, split_axis, concat_axis, vmap_axis).shape, expected_shape) diff --git a/tests/random_test.py b/tests/random_test.py index 445d7e12a..26f7f65c6 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1231,7 +1231,7 @@ class LaxRandomTest(jtu.JaxTestCase): sigma = jnp.ones((2, 2)) key = jax.random.PRNGKey(0) result = jax.random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method) - self.assertAllClose(result[:, 0], result[:, 1]) + self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3) # Cholesky fails for singular inputs. if method == 'cholesky': diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 0007a099e..a5f4489a9 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -31,7 +31,7 @@ from jax.config import config config.parse_flags_with_absl() scipy_version = tuple(map(int, scipy.version.version.split('.')[:3])) -numpy_version = tuple(map(int, np.version.version.split('.')[:3])) +numpy_version = jtu.numpy_version() all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]