From 4c0d61a1435b70760814f1f678cb041d36b8408d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 16 Jun 2022 13:59:53 -0700 Subject: [PATCH] Add jtu.strict_promotion_if_dtypes_match utility --- jax/_src/test_util.py | 11 +++++++ tests/lax_numpy_test.py | 60 ++++++++++++++++++--------------------- tests/scipy_stats_test.py | 56 ++++++++++++++++-------------------- tests/sparse_test.py | 12 ++------ 4 files changed, 66 insertions(+), 73 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 2437b49f7..544718a64 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -30,6 +30,7 @@ from absl.testing import parameterized import numpy as np import numpy.random as npr +import jax from jax._src import api from jax import core from jax._src import dtypes as _dtypes @@ -1053,3 +1054,13 @@ class DeprecatedBufferDonationTestCase(BufferDonationTestCase): as np.testing.assert_allclose(), which work directly with JAX arrays."""), category=DeprecationWarning) super().__init__(*args, **kwargs) + + +def strict_promotion_if_dtypes_match(dtypes): + """ + Context manager to enable strict promotion if all dtypes match, + and enable standard dtype promotion otherwise. + """ + if all(dtype == dtypes[0] for dtype in dtypes): + return jax.numpy_dtype_promotion('strict') + return jax.numpy_dtype_promotion('standard') diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c57ba9800..430f2cde6 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -119,12 +119,6 @@ def _get_y_shapes(y_dtype, shape, rowvar): return [(shape[0], 1), (shape[0], 2), (shape[0], 5)] -def _strict_promotion_if_dtypes_match(dtypes): - if all(dtype == dtypes[0] for dtype in dtypes): - return jax.numpy_dtype_promotion('strict') - return jax.numpy_dtype_promotion('standard') - - OpRecord = collections.namedtuple( "OpRecord", ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", @@ -595,7 +589,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol = functools.reduce(jtu.join_tolerance, [tolerance, tol, jtu.default_tolerance()]) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), jnp_op, args_maker, check_dtypes=check_dtypes, tol=tol) self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes, @@ -620,7 +614,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # jnp arrays. args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) fun = lambda *xs: getattr(operator, name.strip('_'))(*xs) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(itertools.chain.from_iterable( @@ -644,7 +638,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) fun = lambda fst, snd: getattr(snd, name)(fst) tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -738,7 +732,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes): rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) @@ -771,7 +765,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_op = getattr(np, op.__name__) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck(op, args_maker) self._CheckAgainstNumpy(np_op, op, args_maker) @@ -1268,7 +1262,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} tol = max(jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec)) - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) @@ -1302,7 +1296,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol) self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol) @@ -1336,7 +1330,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if jtu.device_under_test() == "tpu": tol[np.float16] = tol[np.float32] = tol[np.complex64] = 4e-2 - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol) @@ -1371,7 +1365,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if jtu.device_under_test() == "tpu": tol[np.float16] = tol[np.float32] = tol[np.complex64] = 2e-1 - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker) @@ -1445,7 +1439,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - with _strict_promotion_if_dtypes_match([dtype1, dtype2]): + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -1472,7 +1466,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) def jnp_fun(arg1, arg2): return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) - with _strict_promotion_if_dtypes_match([dtype1, dtype2]): + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -1491,7 +1485,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def np_fun(arg1, arg2): dtype = jnp.promote_types(arg1.dtype, arg2.dtype) return np.union1d(arg1, arg2).astype(dtype) - with _strict_promotion_if_dtypes_match([dtype1, dtype2]): + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -1519,7 +1513,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) def jnp_fun(arg1, arg2): return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value) - with _strict_promotion_if_dtypes_match([dtype1, dtype2]): + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -1545,7 +1539,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): ar1 = np.ravel(ar1) ar2 = np.ravel(ar2) return np.setxor1d(ar1, ar2, assume_unique) - with _strict_promotion_if_dtypes_match([dtype1, dtype2]): + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) @parameterized.named_parameters(jtu.cases_from_list( @@ -1567,7 +1561,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - with _strict_promotion_if_dtypes_match([dtype1, dtype2]): + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) @@ -1599,7 +1593,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol = max(jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec)) # TODO(phawkins): there are float32/float64 disagreements for some inputs. - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) @@ -2280,7 +2274,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - with _strict_promotion_if_dtypes_match(arg_dtypes): + with jtu.strict_promotion_if_dtypes_match(arg_dtypes): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -2329,7 +2323,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - with _strict_promotion_if_dtypes_match(arg_dtypes): + with jtu.strict_promotion_if_dtypes_match(arg_dtypes): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -3075,7 +3069,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype), np.sort(rng((20,), dtype)), rng((20,), target_dtype)] - with _strict_promotion_if_dtypes_match([dtype, target_dtype]): + with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]): # skip numpy comparison for integer types with period specified, because numpy # uses an unstable sort and so results differ for duplicate values. if not (period and np.issubdtype(dtype, np.integer)): @@ -3262,7 +3256,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] np_fun = _promote_like_jnp(np.column_stack) jnp_fun = jnp.column_stack - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -3288,7 +3282,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] np_fun = _promote_like_jnp(partial(np.stack, axis=axis)) jnp_fun = partial(jnp.stack, axis=axis) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -3314,7 +3308,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] np_fun = _promote_like_jnp(getattr(np, op)) jnp_fun = getattr(jnp, op) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -4973,7 +4967,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, shapes, dtypes) def np_fun(cond, x, y): return _promote_like_jnp(partial(np.where, cond))(x, y) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(np_fun, jnp.where, args_maker) self._CompileAndCheck(jnp.where, args_maker) @@ -5007,7 +5001,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): return np.select(condlist, [np.asarray(x, dtype=dtype) for x in choicelist], np.asarray(default, dtype=dtype)) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, check_dtypes=False) self._CompileAndCheck(jnp.select, args_maker, @@ -6118,7 +6112,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): else: tol = {np.complex64: 1e-5, np.complex128: 1e-14} - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol) self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol) @@ -6143,7 +6137,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): else: tol = {np.complex64: 1e-5, np.complex128: 1e-14} - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol) self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol) @@ -6467,7 +6461,7 @@ class NumpyUfuncTests(jtu.JaxTestCase): message="divide by zero.*")(np_op) args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes) - with _strict_promotion_if_dtypes_match(arg_dtypes): + with jtu.strict_promotion_if_dtypes_match(arg_dtypes): try: jnp_op(*args_maker()) except NotImplementedError: diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index dcad58481..2dfcde31d 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -34,12 +34,6 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] scipy_version = tuple(map(int, osp.version.version.split('.')[:2])) -def _strict_promotion_if_dtypes_match(dtypes): - if all(dtype == dtypes[0] for dtype in dtypes): - return jax.numpy_dtype_promotion('strict') - return jax.numpy_dtype_promotion('standard') - - def genNamedParametersNArgs(n): return parameterized.named_parameters( jtu.cases_from_list( @@ -68,7 +62,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): loc = np.floor(loc) return [k, mu, loc] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}) @@ -87,7 +81,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): loc = np.floor(loc) return [k, mu, loc] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @@ -104,7 +98,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) return [k, mu, loc] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @@ -123,7 +117,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): loc = np.floor(loc) return [x, p, loc] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -141,7 +135,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): loc = np.floor(loc) return [x, p, loc] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -156,7 +150,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, a, b, loc, scale = map(rng, shapes, dtypes) return [x, a, b, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, @@ -181,7 +175,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -222,7 +216,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol = {np.float32: 1E-3, np.float64: 1e-5} - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol) @@ -237,7 +231,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -252,7 +246,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, a, loc, scale = map(rng, shapes, dtypes) return [x, a, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -272,7 +266,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, p = map(rng, shapes, dtypes) return [x, p] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4, rtol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @@ -287,7 +281,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, p = map(rng, shapes, dtypes) return [x, p] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4, rtol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @@ -308,7 +302,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol = {np.float32: 1e-6, np.float64: 1e-8} - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) @@ -325,7 +319,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -342,7 +336,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol={np.float32: 1e-5, np.float64: 1e-6}) self._CompileAndCheck(lax_fun, args_maker) @@ -356,7 +350,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): return list(map(rng, shapes, dtypes)) - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @@ -419,7 +413,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @@ -437,7 +431,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -455,7 +449,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @@ -475,7 +469,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [q, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @@ -490,7 +484,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, b, loc, scale = map(rng, shapes, dtypes) return [x, b, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @@ -508,7 +502,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, df, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, @@ -525,7 +519,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, np.abs(scale)] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -540,7 +534,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, df, loc, scale = map(rng, shapes, dtypes) return [x, df, loc, scale] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) @@ -559,7 +553,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): loc = np.floor(loc) return [k, n, a, b, loc] - with _strict_promotion_if_dtypes_match(dtypes): + with jtu.strict_promotion_if_dtypes_match(dtypes): if scipy_version >= (1, 4): scipy_fun = osp_stats.betabinom.logpmf self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 9841981be..c2c175c86 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -109,12 +109,6 @@ def _generate_bcoo_dot_general_properties(shapes, dtypes) -> BcooDotGeneralPrope ) -def _strict_promotion_if_dtypes_match(dtypes): - if all(dtype == dtypes[0] for dtype in dtypes): - return jax.numpy_dtype_promotion('strict') - return jax.numpy_dtype_promotion('standard') - - all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex @@ -1846,7 +1840,7 @@ class BCOOTest(jtu.JaxTestCase): lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=max(0, len(lhs_shape) - 2)) rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=max(0, len(rhs_shape) - 2)) - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): out1 = lhs @ rhs out2 = lhs_sp @ rhs out3 = lhs @ rhs_sp @@ -1881,7 +1875,7 @@ class BCOOTest(jtu.JaxTestCase): sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense) - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): out1 = lhs * rhs out2 = (sp(lhs) * rhs).todense() out3 = (rhs * sp(lhs)).todense() @@ -1917,7 +1911,7 @@ class BCOOTest(jtu.JaxTestCase): lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch, n_dense=n_dense) rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch, n_dense=n_dense) - with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): out1 = lhs * rhs out2 = (lhs_sp * rhs_sp).todense()