shape_poly_test: adjust configs via jtu.global_config_context

This commit is contained in:
Jake VanderPlas 2024-06-05 10:45:56 -07:00
parent da87e4470a
commit f04a2279a5
3 changed files with 12 additions and 25 deletions

View File

@ -1011,14 +1011,16 @@ def promote_like_jnp(fun, inexact=False):
return wrapper
@contextmanager
def config_context(**kwds):
def global_config_context(**kwds):
original_config = {}
for key, value in kwds.items():
original_config[key] = config._read(key)
config.update(key, value)
yield
for key, value in original_config.items():
config.update(key, value)
try:
for key, value in kwds.items():
original_config[key] = config._read(key)
config.update(key, value)
yield
finally:
for key, value in original_config.items():
config.update(key, value)
class NotPresent:
@ -1071,7 +1073,7 @@ class JaxTestCase(parameterized.TestCase):
def setUpClass(cls):
cls._compilation_cache_exit_stack = ExitStack()
stack = cls._compilation_cache_exit_stack
stack.enter_context(config_context(**cls._default_config))
stack.enter_context(global_config_context(**cls._default_config))
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
stack.enter_context(config.enable_compilation_cache(True))

View File

@ -2815,17 +2815,8 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("JAX implements eig only on CPU.")
prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()
}
try:
for fname, fvalue in harness.override_jax_config_flags.items():
jax.config.update(fname, fvalue)
with jtu.global_config_context(**harness.override_jax_config_flags):
harness.run_test(self)
finally:
for fname, _ in harness.override_jax_config_flags.items():
jax.config.update(fname, prev_jax_config_flags[fname])
if __name__ == "__main__":

View File

@ -3327,14 +3327,8 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
if "random_gamma" in harness.group_name:
config_flags = {**config_flags, "jax_debug_key_reuse": False}
prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags}
try:
for fname, fvalue in config_flags.items():
jax.config.update(fname, fvalue)
with jtu.global_config_context(**config_flags):
harness.run_test(self)
finally:
for fname, _ in config_flags.items():
jax.config.update(fname, prev_jax_config_flags[fname])
if __name__ == "__main__":