Make JAX tests pass under NumPy 1.24.0rc2.

* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
This commit is contained in:
Peter Hawkins 2022-12-08 19:40:56 +00:00
parent da285b6536
commit 73de02d5ce
11 changed files with 29 additions and 17 deletions

View File

@ -1060,7 +1060,7 @@ def _get_monoid_reducer(monoid_op: Callable,
return None
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
return np.array(-1, dtype)
return np.array(-1).astype(dtype)
def _get_bitwise_or_identity(dtype: DTypeLike) -> np.ndarray:
return np.array(0, dtype)

View File

@ -1127,3 +1127,13 @@ def strict_promotion_if_dtypes_match(dtypes):
if all(dtype == dtypes[0] for dtype in dtypes):
return jax.numpy_dtype_promotion('strict')
return jax.numpy_dtype_promotion('standard')
_version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?")
def _parse_version(v: str) -> Tuple[int, ...]:
m = _version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse version '{v}'")
return tuple(int(x) for x in m.group(1).split('.'))
def numpy_version():
return _parse_version(np.__version__)

View File

@ -78,8 +78,7 @@ FLAGS = config.FLAGS
python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
numpy_version = jtu.numpy_version()
def _check_instance(self, x):
if config.jax_array:

View File

@ -25,7 +25,7 @@ from jax._src import test_util as jtu
import numpy as np
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
numpy_version = jtu.numpy_version()
config.parse_flags_with_absl()

View File

@ -35,7 +35,6 @@ FLAGS = config.FLAGS
python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
# TODO(https://github.com/google/jax/issues/12291): Enable jax.Array

View File

@ -32,7 +32,7 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
numpy_version = jtu.numpy_version()
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes

View File

@ -52,7 +52,7 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
numpy_version = jtu.numpy_version()
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
@ -2315,7 +2315,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertEqual(out_int64.dtype, np.int64)
else:
with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"):
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
with jtu.ignore_warning(category=DeprecationWarning,
message="NumPy will stop allowing conversion.*"):
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
@jtu.sample_product(
dtype=inexact_dtypes,
@ -2432,7 +2434,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if numpy_version < (1, 24) or op == "dstack":
np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
else:
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype)
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
casting='unsafe')
jnp_fun = partial(getattr(jnp, op), dtype=out_dtype)
with jtu.strict_promotion_if_dtypes_match(dtypes):
@ -3051,19 +3054,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64'))
def testArrayFromList(self):
int_max = jnp.iinfo(jnp.int64).max
int_min = jnp.iinfo(jnp.int64).min
dtype = dtypes.canonicalize_dtype('int64')
int_max = jnp.iinfo(dtype).max
int_min = jnp.iinfo(dtype).min
# Values at extremes are converted correctly.
for val in [int_min, 0, int_max]:
self.assertEqual(jnp.array([val]).dtype, dtypes.canonicalize_dtype('int64'))
self.assertEqual(jnp.array([val]).dtype, dtype)
# list of values results in promoted type.
with jax.numpy_dtype_promotion('standard'):
self.assertEqual(jnp.array([0, np.float16(1)]).dtype, jnp.result_type('int64', 'float16'))
# out of bounds leads to an OverflowError
val = int_min - 1
val = jnp.iinfo(jnp.int64).min - 1
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
jnp.array([0, val])

View File

@ -1650,7 +1650,7 @@ class LaxTest(jtu.JaxTestCase):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)
init_val = np.asarray(init_val).astype(dtype)
fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
args_maker = lambda: [rng(shape, dtype), init_val]
self._CompileAndCheck(fun, args_maker)

View File

@ -2294,7 +2294,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
instance_shape.insert(concat_axis, pmap_dim_id)
expected_shape = (split_axis_id, vmap_dim_id, *instance_shape)
x = np.empty(start_shape)
x = np.ones(start_shape)
self.assertEqual(reference(x, split_axis, concat_axis, vmap_axis).shape,
expected_shape)

View File

@ -1231,7 +1231,7 @@ class LaxRandomTest(jtu.JaxTestCase):
sigma = jnp.ones((2, 2))
key = jax.random.PRNGKey(0)
result = jax.random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
self.assertAllClose(result[:, 0], result[:, 1])
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)
# Cholesky fails for singular inputs.
if method == 'cholesky':

View File

@ -31,7 +31,7 @@ from jax.config import config
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
numpy_version = tuple(map(int, np.version.version.split('.')[:3]))
numpy_version = jtu.numpy_version()
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]