Make sure that tests don't change the state of the compilation cache

If it was initialized before the test, it should stay so after. And the other
way around too.

PiperOrigin-RevId: 726899671
This commit is contained in:
Adam Paszke 2025-02-14 06:11:12 -08:00 committed by jax authors
parent 49ad24152c
commit 5ab8c5a8a4
3 changed files with 16 additions and 9 deletions

View File

@ -126,6 +126,7 @@ py_library(
":internal",
] + jax_test_util_visibility,
deps = [
":compilation_cache_internal",
":jax",
] + py_deps("absl/testing") + py_deps("numpy"),
)

View File

@ -42,6 +42,7 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax._src import api
from jax._src import compilation_cache
from jax._src import config
from jax._src import core
from jax._src import deprecations
@ -60,7 +61,6 @@ from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance)
from jax._src.util import unzip2
from jax.experimental.compilation_cache import compilation_cache
from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten
import numpy as np
import numpy.random as npr
@ -1265,17 +1265,23 @@ class NotPresent:
@contextmanager
def assert_global_configs_unchanged():
starting_cache = compilation_cache._cache
starting_config = jax.config.values.copy()
yield
ending_config = jax.config.values
ending_cache = compilation_cache._cache
if starting_config == ending_config:
return
differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent()))
for k in (starting_config.keys() | ending_config.keys())
if (k not in starting_config or k not in ending_config
or starting_config[k] != ending_config[k])}
raise AssertionError(f"Test changed global config values. Differing values are: {differing}")
if starting_config != ending_config:
differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent()))
for k in (starting_config.keys() | ending_config.keys())
if (k not in starting_config or k not in ending_config
or starting_config[k] != ending_config[k])}
raise AssertionError(f"Test changed global config values. Differing values are: {differing}")
if starting_cache is not ending_cache:
raise AssertionError(
f"Test changed the compilation cache object: before test it was "
f"{starting_cache}, now it is {ending_cache}"
)
class JaxTestCase(parameterized.TestCase):

View File

@ -350,8 +350,8 @@ class PgleTest(jtu.JaxTestCase):
# Test pass fdo_profile as compiler_options API works.
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
def testPersistentCachePopulatedWithAutoPgle(self):
self.skipTest('Test does not cleanly reset the compilation cache')
its = 50
mesh = jtu.create_mesh((2,), ('x',))