mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make sure that tests don't change the state of the compilation cache
If it was initialized before the test, it should stay so after. And the other way around too. PiperOrigin-RevId: 726899671
This commit is contained in:
parent
49ad24152c
commit
5ab8c5a8a4
@ -126,6 +126,7 @@ py_library(
|
||||
":internal",
|
||||
] + jax_test_util_visibility,
|
||||
deps = [
|
||||
":compilation_cache_internal",
|
||||
":jax",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
@ -42,6 +42,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import api
|
||||
from jax._src import compilation_cache
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
@ -60,7 +61,6 @@ from jax._src.public_test_util import ( # noqa: F401
|
||||
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
|
||||
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance)
|
||||
from jax._src.util import unzip2
|
||||
from jax.experimental.compilation_cache import compilation_cache
|
||||
from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten
|
||||
import numpy as np
|
||||
import numpy.random as npr
|
||||
@ -1265,17 +1265,23 @@ class NotPresent:
|
||||
|
||||
@contextmanager
|
||||
def assert_global_configs_unchanged():
|
||||
starting_cache = compilation_cache._cache
|
||||
starting_config = jax.config.values.copy()
|
||||
yield
|
||||
ending_config = jax.config.values
|
||||
ending_cache = compilation_cache._cache
|
||||
|
||||
if starting_config == ending_config:
|
||||
return
|
||||
differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent()))
|
||||
for k in (starting_config.keys() | ending_config.keys())
|
||||
if (k not in starting_config or k not in ending_config
|
||||
or starting_config[k] != ending_config[k])}
|
||||
raise AssertionError(f"Test changed global config values. Differing values are: {differing}")
|
||||
if starting_config != ending_config:
|
||||
differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent()))
|
||||
for k in (starting_config.keys() | ending_config.keys())
|
||||
if (k not in starting_config or k not in ending_config
|
||||
or starting_config[k] != ending_config[k])}
|
||||
raise AssertionError(f"Test changed global config values. Differing values are: {differing}")
|
||||
if starting_cache is not ending_cache:
|
||||
raise AssertionError(
|
||||
f"Test changed the compilation cache object: before test it was "
|
||||
f"{starting_cache}, now it is {ending_cache}"
|
||||
)
|
||||
|
||||
|
||||
class JaxTestCase(parameterized.TestCase):
|
||||
|
@ -350,8 +350,8 @@ class PgleTest(jtu.JaxTestCase):
|
||||
# Test pass fdo_profile as compiler_options API works.
|
||||
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
|
||||
|
||||
|
||||
def testPersistentCachePopulatedWithAutoPgle(self):
|
||||
self.skipTest('Test does not cleanly reset the compilation cache')
|
||||
its = 50
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user