mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests. * don't cast np.empty(), which can lead to cast errors. * NumPy 1.24 now warns on overflowing scalar int to array casts in more places.
This commit is contained in:
parent
da285b6536
commit
73de02d5ce
@ -1060,7 +1060,7 @@ def _get_monoid_reducer(monoid_op: Callable,
|
||||
return None
|
||||
|
||||
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
|
||||
return np.array(-1, dtype)
|
||||
return np.array(-1).astype(dtype)
|
||||
|
||||
def _get_bitwise_or_identity(dtype: DTypeLike) -> np.ndarray:
|
||||
return np.array(0, dtype)
|
||||
|
@ -1127,3 +1127,13 @@ def strict_promotion_if_dtypes_match(dtypes):
|
||||
if all(dtype == dtypes[0] for dtype in dtypes):
|
||||
return jax.numpy_dtype_promotion('strict')
|
||||
return jax.numpy_dtype_promotion('standard')
|
||||
|
||||
_version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?")
|
||||
def _parse_version(v: str) -> Tuple[int, ...]:
|
||||
m = _version_regex.match(v)
|
||||
if m is None:
|
||||
raise ValueError(f"Unable to parse version '{v}'")
|
||||
return tuple(int(x) for x in m.group(1).split('.'))
|
||||
|
||||
def numpy_version():
|
||||
return _parse_version(np.__version__)
|
||||
|
@ -78,8 +78,7 @@ FLAGS = config.FLAGS
|
||||
|
||||
|
||||
python_version = (sys.version_info[0], sys.version_info[1])
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
|
||||
numpy_version = jtu.numpy_version()
|
||||
|
||||
def _check_instance(self, x):
|
||||
if config.jax_array:
|
||||
|
@ -25,7 +25,7 @@ from jax._src import test_util as jtu
|
||||
|
||||
import numpy as np
|
||||
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
numpy_version = jtu.numpy_version()
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
@ -35,7 +35,6 @@ FLAGS = config.FLAGS
|
||||
|
||||
|
||||
python_version = (sys.version_info[0], sys.version_info[1])
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
|
||||
|
||||
# TODO(https://github.com/google/jax/issues/12291): Enable jax.Array
|
||||
|
@ -32,7 +32,7 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
numpy_version = jtu.numpy_version()
|
||||
|
||||
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
|
||||
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
|
||||
|
@ -52,7 +52,7 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
numpy_version = jtu.numpy_version()
|
||||
|
||||
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
|
||||
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
|
||||
@ -2315,7 +2315,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertEqual(out_int64.dtype, np.int64)
|
||||
else:
|
||||
with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"):
|
||||
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="NumPy will stop allowing conversion.*"):
|
||||
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=inexact_dtypes,
|
||||
@ -2432,7 +2434,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if numpy_version < (1, 24) or op == "dstack":
|
||||
np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
|
||||
else:
|
||||
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype)
|
||||
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
|
||||
casting='unsafe')
|
||||
|
||||
jnp_fun = partial(getattr(jnp, op), dtype=out_dtype)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
@ -3051,19 +3054,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64'))
|
||||
|
||||
def testArrayFromList(self):
|
||||
int_max = jnp.iinfo(jnp.int64).max
|
||||
int_min = jnp.iinfo(jnp.int64).min
|
||||
dtype = dtypes.canonicalize_dtype('int64')
|
||||
int_max = jnp.iinfo(dtype).max
|
||||
int_min = jnp.iinfo(dtype).min
|
||||
|
||||
# Values at extremes are converted correctly.
|
||||
for val in [int_min, 0, int_max]:
|
||||
self.assertEqual(jnp.array([val]).dtype, dtypes.canonicalize_dtype('int64'))
|
||||
self.assertEqual(jnp.array([val]).dtype, dtype)
|
||||
|
||||
# list of values results in promoted type.
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
self.assertEqual(jnp.array([0, np.float16(1)]).dtype, jnp.result_type('int64', 'float16'))
|
||||
|
||||
# out of bounds leads to an OverflowError
|
||||
val = int_min - 1
|
||||
val = jnp.iinfo(jnp.int64).min - 1
|
||||
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
|
||||
jnp.array([0, val])
|
||||
|
||||
|
@ -1650,7 +1650,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
|
||||
else jtu.rand_small)
|
||||
rng = rng_factory(self.rng())
|
||||
init_val = np.asarray(init_val, dtype=dtype)
|
||||
init_val = np.asarray(init_val).astype(dtype)
|
||||
fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
|
||||
args_maker = lambda: [rng(shape, dtype), init_val]
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
@ -2294,7 +2294,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
instance_shape.insert(concat_axis, pmap_dim_id)
|
||||
expected_shape = (split_axis_id, vmap_dim_id, *instance_shape)
|
||||
|
||||
x = np.empty(start_shape)
|
||||
x = np.ones(start_shape)
|
||||
self.assertEqual(reference(x, split_axis, concat_axis, vmap_axis).shape,
|
||||
expected_shape)
|
||||
|
||||
|
@ -1231,7 +1231,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
sigma = jnp.ones((2, 2))
|
||||
key = jax.random.PRNGKey(0)
|
||||
result = jax.random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
|
||||
self.assertAllClose(result[:, 0], result[:, 1])
|
||||
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)
|
||||
|
||||
# Cholesky fails for singular inputs.
|
||||
if method == 'cholesky':
|
||||
|
@ -31,7 +31,7 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
|
||||
numpy_version = tuple(map(int, np.version.version.split('.')[:3]))
|
||||
numpy_version = jtu.numpy_version()
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user