mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
disable rank promotion for jax scipy tests
This commit is contained in:
parent
5c3fc6c9f0
commit
30ea76cb6b
@ -64,6 +64,16 @@ def rand_sym_pos_def(rng, shape, dtype):
|
||||
|
||||
|
||||
class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
|
||||
config.update("jax_numpy_rank_promotion", "raise")
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
|
||||
super().tearDown()
|
||||
|
||||
def _fetch_preconditioner(self, preconditioner, A, rng=None):
|
||||
"""
|
||||
Returns one of various preconditioning matrices depending on the identifier
|
||||
|
@ -144,6 +144,15 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed Scipy implementation."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
|
||||
config.update("jax_numpy_rank_promotion", "raise")
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
|
||||
super().tearDown()
|
||||
|
||||
def _GetArgsMaker(self, rng, shapes, dtypes):
|
||||
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
|
||||
@ -165,8 +174,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
max(len(shape) for shape in shapes))
|
||||
for keepdims in [False, True]
|
||||
for return_sign in [False, True]))
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="invalid value encountered in .*")
|
||||
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*")
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def testLogSumExp(self, shapes, dtype, axis,
|
||||
keepdims, return_sign, use_b):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
@ -226,6 +235,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
|
||||
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
|
||||
for rec in JAX_SPECIAL_FUNCTION_RECORDS))
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
|
||||
test_autodiff, nondiff_argnums):
|
||||
if (jtu.device_under_test() == "cpu" and
|
||||
@ -592,7 +602,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
return
|
||||
S_expected = np.linalg.svd(A, compute_uv=False)
|
||||
U, S, V = jax._src.scipy.eigh.svd(A)
|
||||
recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
|
||||
recon = jnp.dot((U * jnp.expand_dims(S, 0)), V,
|
||||
precision=lax.Precision.HIGHEST)
|
||||
eps = jnp.finfo(dtype).eps
|
||||
eps = eps * jnp.linalg.norm(A) * 10
|
||||
self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
|
||||
|
@ -59,6 +59,15 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
|
||||
|
||||
class NdimageTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
|
||||
config.update("jax_numpy_rank_promotion", "raise")
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
|
||||
super().tearDown()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
|
@ -66,6 +66,15 @@ def zakharovFromIndices(x, ii):
|
||||
|
||||
class TestBFGS(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
|
||||
config.update("jax_numpy_rank_promotion", "raise")
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
|
||||
super().tearDown()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
|
||||
"maxiter": maxiter, "func_and_init": func_and_init}
|
||||
@ -142,6 +151,15 @@ class TestBFGS(jtu.JaxTestCase):
|
||||
|
||||
class TestLBFGS(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
|
||||
config.update("jax_numpy_rank_promotion", "raise")
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
|
||||
super().tearDown()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
|
||||
"maxiter": maxiter, "func_and_init": func_and_init}
|
||||
|
@ -38,6 +38,15 @@ default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed scipy.stats implementations"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
|
||||
config.update("jax_numpy_rank_promotion", "raise")
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
|
||||
super().tearDown()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
|
||||
op,
|
||||
|
Loading…
x
Reference in New Issue
Block a user