disable rank promotion for jax scipy tests

Jake VanderPlas 2021-08-04 10:44:23 -07:00
parent 5c3fc6c9f0
commit 30ea76cb6b
5 changed files with 60 additions and 3 deletions
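
For context: JAX's jax_numpy_rank_promotion flag controls whether operands of different rank may be implicitly broadcast against one another. The test changes below force it to "raise" so that any implicit rank promotion inside jax.scipy code paths fails loudly. A minimal sketch of the behavior being toggled (illustrative, not part of this commit):

import jax
import jax.numpy as jnp

x = jnp.ones((3, 4))  # rank 2
y = jnp.ones((4,))    # rank 1

with jax.numpy_rank_promotion('allow'):
  print((x + y).shape)  # (3, 4): y is implicitly promoted to shape (1, 4)

with jax.numpy_rank_promotion('raise'):
  try:
    x + y
  except Exception as e:  # promotion from rank 1 to rank 2 is disallowed
    print("rank promotion error:", e)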

View File

@@ -64,6 +64,16 @@ def rand_sym_pos_def(rng, shape, dtype):
class LaxBackedScipyTests(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
    super().tearDown()

  def _fetch_preconditioner(self, preconditioner, A, rng=None):
    """
    Returns one of various preconditioning matrices depending on the identifier
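
The same save-and-restore pattern recurs in every test class touched by this commit: setUp stashes the current jax_numpy_rank_promotion value, forces it to "raise" for the duration of each test, and tearDown restores it so the global config does not leak into other test modules. A sketch of the pattern factored into a reusable mixin (RankPromotionRaiseMixin is a hypothetical name, not part of this commit, and it assumes the config object is imported from jax.config as in these test files):

from jax.config import config

class RankPromotionRaiseMixin:
  """Hypothetical mixin: run every test with implicit rank promotion disabled."""

  def setUp(self):
    super().setUp()
    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    # Restore whatever value was in effect before the test ran.
    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
    super().tearDown()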

View File

@@ -144,6 +144,15 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def setUp(self):
    super().setUp()
    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
    super().tearDown()

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
@@ -165,8 +174,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
                        max(len(shape) for shape in shapes))
      for keepdims in [False, True]
      for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*")
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
    if jtu.device_under_test() != "cpu":
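
jax.numpy_rank_promotion('allow') is applied above as a decorator so that this one test can keep exercising implicit promotion while the class-wide default is "raise". The same helper is also usable as a context manager when only part of a function needs promotion; a minimal sketch (weighted_logsumexp is an illustrative name, not from this commit):

import jax
import jax.numpy as jnp

def weighted_logsumexp(x, b):
  # x has rank 2 and b has rank 1, so b * jnp.exp(x) needs implicit promotion.
  with jax.numpy_rank_promotion('allow'):
    scaled = b * jnp.exp(x)
  return jnp.log(jnp.sum(scaled, axis=-1))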
@@ -226,6 +235,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
        for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
                       if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    if (jtu.device_under_test() == "cpu" and
@@ -592,7 +602,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
      return
    S_expected = np.linalg.svd(A, compute_uv=False)
    U, S, V = jax._src.scipy.eigh.svd(A)
    recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
    recon = jnp.dot((U * jnp.expand_dims(S, 0)), V,
                    precision=lax.Precision.HIGHEST)
    eps = jnp.finfo(dtype).eps
    eps = eps * jnp.linalg.norm(A) * 10
    self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
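
The recon change above makes the broadcast of the singular values explicit: S has rank 1 while U has rank 2, so U * S relies on implicit rank promotion, which the new class-level "raise" setting forbids. jnp.expand_dims(S, 0) gives both operands the same rank, leaving only ordinary size-1 broadcasting, which is still allowed. A shape-level sketch with dummy arrays (illustrative only):

import jax.numpy as jnp

U = jnp.ones((4, 3))   # left singular vectors, rank 2
S = jnp.ones((3,))     # singular values, rank 1

# U * S would promote S from rank 1 to rank 2, which "raise" disallows.
S2 = jnp.expand_dims(S, 0)   # shape (1, 3): rank now matches U
print((U * S2).shape)        # (4, 3) via ordinary size-1 broadcasting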

View File

@@ -59,6 +59,15 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
class NdimageTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
    super().tearDown()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".format(
          jtu.format_shape_dtype_string(shape, dtype),

View File

@@ -66,6 +66,15 @@ def zakharovFromIndices(x, ii):
class TestBFGS(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
    super().tearDown()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
       "maxiter": maxiter, "func_and_init": func_and_init}
@@ -142,6 +151,15 @@ class TestBFGS(jtu.JaxTestCase):
class TestLBFGS(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
    super().tearDown()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
       "maxiter": maxiter, "func_and_init": func_and_init}

View File

@@ -38,6 +38,15 @@ default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
class LaxBackedScipySignalTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations"""

  def setUp(self):
    super().setUp()
    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
    super().tearDown()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,