mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
shape_poly_test: adjust configs via jtu.global_config_context
This commit is contained in:
parent
da87e4470a
commit
f04a2279a5
@ -1011,14 +1011,16 @@ def promote_like_jnp(fun, inexact=False):
|
||||
return wrapper
|
||||
|
||||
@contextmanager
|
||||
def config_context(**kwds):
|
||||
def global_config_context(**kwds):
|
||||
original_config = {}
|
||||
for key, value in kwds.items():
|
||||
original_config[key] = config._read(key)
|
||||
config.update(key, value)
|
||||
yield
|
||||
for key, value in original_config.items():
|
||||
config.update(key, value)
|
||||
try:
|
||||
for key, value in kwds.items():
|
||||
original_config[key] = config._read(key)
|
||||
config.update(key, value)
|
||||
yield
|
||||
finally:
|
||||
for key, value in original_config.items():
|
||||
config.update(key, value)
|
||||
|
||||
|
||||
class NotPresent:
|
||||
@ -1071,7 +1073,7 @@ class JaxTestCase(parameterized.TestCase):
|
||||
def setUpClass(cls):
|
||||
cls._compilation_cache_exit_stack = ExitStack()
|
||||
stack = cls._compilation_cache_exit_stack
|
||||
stack.enter_context(config_context(**cls._default_config))
|
||||
stack.enter_context(global_config_context(**cls._default_config))
|
||||
|
||||
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||
stack.enter_context(config.enable_compilation_cache(True))
|
||||
|
@ -2815,17 +2815,8 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
|
||||
raise unittest.SkipTest("JAX implements eig only on CPU.")
|
||||
|
||||
prev_jax_config_flags = {
|
||||
fname: getattr(jax.config, fname)
|
||||
for fname, fvalue in harness.override_jax_config_flags.items()
|
||||
}
|
||||
try:
|
||||
for fname, fvalue in harness.override_jax_config_flags.items():
|
||||
jax.config.update(fname, fvalue)
|
||||
with jtu.global_config_context(**harness.override_jax_config_flags):
|
||||
harness.run_test(self)
|
||||
finally:
|
||||
for fname, _ in harness.override_jax_config_flags.items():
|
||||
jax.config.update(fname, prev_jax_config_flags[fname])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -3327,14 +3327,8 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
if "random_gamma" in harness.group_name:
|
||||
config_flags = {**config_flags, "jax_debug_key_reuse": False}
|
||||
|
||||
prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags}
|
||||
try:
|
||||
for fname, fvalue in config_flags.items():
|
||||
jax.config.update(fname, fvalue)
|
||||
with jtu.global_config_context(**config_flags):
|
||||
harness.run_test(self)
|
||||
finally:
|
||||
for fname, _ in config_flags.items():
|
||||
jax.config.update(fname, prev_jax_config_flags[fname])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user