mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add jtu.strict_promotion_if_dtypes_match utility
This commit is contained in:
parent
63755156ea
commit
4c0d61a143
@ -30,6 +30,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import numpy.random as npr
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax._src import dtypes as _dtypes
|
||||
@ -1053,3 +1054,13 @@ class DeprecatedBufferDonationTestCase(BufferDonationTestCase):
|
||||
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
|
||||
category=DeprecationWarning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def strict_promotion_if_dtypes_match(dtypes):
|
||||
"""
|
||||
Context manager to enable strict promotion if all dtypes match,
|
||||
and enable standard dtype promotion otherwise.
|
||||
"""
|
||||
if all(dtype == dtypes[0] for dtype in dtypes):
|
||||
return jax.numpy_dtype_promotion('strict')
|
||||
return jax.numpy_dtype_promotion('standard')
|
||||
|
@ -119,12 +119,6 @@ def _get_y_shapes(y_dtype, shape, rowvar):
|
||||
return [(shape[0], 1), (shape[0], 2), (shape[0], 5)]
|
||||
|
||||
|
||||
def _strict_promotion_if_dtypes_match(dtypes):
|
||||
if all(dtype == dtypes[0] for dtype in dtypes):
|
||||
return jax.numpy_dtype_promotion('strict')
|
||||
return jax.numpy_dtype_promotion('standard')
|
||||
|
||||
|
||||
OpRecord = collections.namedtuple(
|
||||
"OpRecord",
|
||||
["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
|
||||
@ -595,7 +589,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol = functools.reduce(jtu.join_tolerance,
|
||||
[tolerance, tol, jtu.default_tolerance()])
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), jnp_op,
|
||||
args_maker, check_dtypes=check_dtypes, tol=tol)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes,
|
||||
@ -620,7 +614,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# jnp arrays.
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
|
||||
fun = lambda *xs: getattr(operator, name.strip('_'))(*xs)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
@ -644,7 +638,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
|
||||
fun = lambda fst, snd: getattr(snd, name)(fst)
|
||||
tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -738,7 +732,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@ -771,7 +765,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
np_op = getattr(np, op.__name__)
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
self._CheckAgainstNumpy(np_op, op, args_maker)
|
||||
|
||||
@ -1268,7 +1262,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15}
|
||||
tol = max(jtu.tolerance(lhs_dtype, tol_spec),
|
||||
jtu.tolerance(rhs_dtype, tol_spec))
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@ -1302,7 +1296,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x
|
||||
y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y
|
||||
return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype))
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@ -1336,7 +1330,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if jtu.device_under_test() == "tpu":
|
||||
tol[np.float16] = tol[np.float32] = tol[np.complex64] = 4e-2
|
||||
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@ -1371,7 +1365,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if jtu.device_under_test() == "tpu":
|
||||
tol[np.float16] = tol[np.float32] = tol[np.complex64] = 2e-1
|
||||
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -1445,7 +1439,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]
|
||||
|
||||
with _strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1472,7 +1466,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0)
|
||||
def jnp_fun(arg1, arg2):
|
||||
return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value)
|
||||
with _strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -1491,7 +1485,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(arg1, arg2):
|
||||
dtype = jnp.promote_types(arg1.dtype, arg2.dtype)
|
||||
return np.union1d(arg1, arg2).astype(dtype)
|
||||
with _strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1519,7 +1513,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
return np.concatenate([result, np.full(size - len(result), fv, result.dtype)])
|
||||
def jnp_fun(arg1, arg2):
|
||||
return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value)
|
||||
with _strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -1545,7 +1539,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
ar1 = np.ravel(ar1)
|
||||
ar2 = np.ravel(ar2)
|
||||
return np.setxor1d(ar1, ar2, assume_unique)
|
||||
with _strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1567,7 +1561,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]
|
||||
jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices)
|
||||
np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices)
|
||||
with _strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
|
||||
@ -1599,7 +1593,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol = max(jtu.tolerance(lhs_dtype, tol_spec),
|
||||
jtu.tolerance(rhs_dtype, tol_spec))
|
||||
# TODO(phawkins): there are float32/float64 disagreements for some inputs.
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol)
|
||||
|
||||
@ -2280,7 +2274,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(arg_dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -2329,7 +2323,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(arg_dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -3075,7 +3069,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype), np.sort(rng((20,), dtype)),
|
||||
rng((20,), target_dtype)]
|
||||
|
||||
with _strict_promotion_if_dtypes_match([dtype, target_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]):
|
||||
# skip numpy comparison for integer types with period specified, because numpy
|
||||
# uses an unstable sort and so results differ for duplicate values.
|
||||
if not (period and np.issubdtype(dtype, np.integer)):
|
||||
@ -3262,7 +3256,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(np.column_stack)
|
||||
jnp_fun = jnp.column_stack
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -3288,7 +3282,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
|
||||
jnp_fun = partial(jnp.stack, axis=axis)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -3314,7 +3308,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(getattr(np, op))
|
||||
jnp_fun = getattr(jnp, op)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -4973,7 +4967,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
def np_fun(cond, x, y):
|
||||
return _promote_like_jnp(partial(np.where, cond))(x, y)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(np_fun, jnp.where, args_maker)
|
||||
self._CompileAndCheck(jnp.where, args_maker)
|
||||
|
||||
@ -5007,7 +5001,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
return np.select(condlist,
|
||||
[np.asarray(x, dtype=dtype) for x in choicelist],
|
||||
np.asarray(default, dtype=dtype))
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(np_fun, jnp.select, args_maker,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(jnp.select, args_maker,
|
||||
@ -6118,7 +6112,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
tol = {np.complex64: 1e-5, np.complex128: 1e-14}
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@ -6143,7 +6137,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
tol = {np.complex64: 1e-5, np.complex128: 1e-14}
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@ -6467,7 +6461,7 @@ class NumpyUfuncTests(jtu.JaxTestCase):
|
||||
message="divide by zero.*")(np_op)
|
||||
args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)
|
||||
|
||||
with _strict_promotion_if_dtypes_match(arg_dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
|
||||
try:
|
||||
jnp_op(*args_maker())
|
||||
except NotImplementedError:
|
||||
|
@ -34,12 +34,6 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
|
||||
scipy_version = tuple(map(int, osp.version.version.split('.')[:2]))
|
||||
|
||||
|
||||
def _strict_promotion_if_dtypes_match(dtypes):
|
||||
if all(dtype == dtypes[0] for dtype in dtypes):
|
||||
return jax.numpy_dtype_promotion('strict')
|
||||
return jax.numpy_dtype_promotion('standard')
|
||||
|
||||
|
||||
def genNamedParametersNArgs(n):
|
||||
return parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -68,7 +62,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = np.floor(loc)
|
||||
return [k, mu, loc]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
|
||||
@ -87,7 +81,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = np.floor(loc)
|
||||
return [k, mu, loc]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -104,7 +98,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
|
||||
return [k, mu, loc]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -123,7 +117,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = np.floor(loc)
|
||||
return [x, p, loc]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -141,7 +135,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = np.floor(loc)
|
||||
return [x, p, loc]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -156,7 +150,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
@ -181,7 +175,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -222,7 +216,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
tol = {np.float32: 1E-3, np.float64: 1e-5}
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
|
||||
@ -237,7 +231,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -252,7 +246,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -272,7 +266,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, p = map(rng, shapes, dtypes)
|
||||
return [x, p]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -287,7 +281,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, p = map(rng, shapes, dtypes)
|
||||
return [x, p]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -308,7 +302,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
tol = {np.float32: 1e-6, np.float64: 1e-8}
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||
@ -325,7 +319,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -342,7 +336,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol={np.float32: 1e-5, np.float64: 1e-6})
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -356,7 +350,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -419,7 +413,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -437,7 +431,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -455,7 +449,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -475,7 +469,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@ -490,7 +484,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, b, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -508,7 +502,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
@ -525,7 +519,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, np.abs(scale)]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -540,7 +534,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
@ -559,7 +553,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = np.floor(loc)
|
||||
return [k, n, a, b, loc]
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
if scipy_version >= (1, 4):
|
||||
scipy_fun = osp_stats.betabinom.logpmf
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
|
@ -109,12 +109,6 @@ def _generate_bcoo_dot_general_properties(shapes, dtypes) -> BcooDotGeneralPrope
|
||||
)
|
||||
|
||||
|
||||
def _strict_promotion_if_dtypes_match(dtypes):
|
||||
if all(dtype == dtypes[0] for dtype in dtypes):
|
||||
return jax.numpy_dtype_promotion('strict')
|
||||
return jax.numpy_dtype_promotion('standard')
|
||||
|
||||
|
||||
all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
|
||||
|
||||
|
||||
@ -1846,7 +1840,7 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=max(0, len(lhs_shape) - 2))
|
||||
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=max(0, len(rhs_shape) - 2))
|
||||
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
out1 = lhs @ rhs
|
||||
out2 = lhs_sp @ rhs
|
||||
out3 = lhs @ rhs_sp
|
||||
@ -1881,7 +1875,7 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
|
||||
sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)
|
||||
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
out1 = lhs * rhs
|
||||
out2 = (sp(lhs) * rhs).todense()
|
||||
out3 = (rhs * sp(lhs)).todense()
|
||||
@ -1917,7 +1911,7 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch, n_dense=n_dense)
|
||||
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch, n_dense=n_dense)
|
||||
|
||||
with _strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
out1 = lhs * rhs
|
||||
out2 = (lhs_sp * rhs_sp).todense()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user