To avoid the inconsistency between process_index and process_id, replace backend.process_index with distributed.global_state.process_id in Jax compilation _cache_write function.

Testing: new unit test.
PiperOrigin-RevId: 607385112
This commit is contained in:
jax authors 2024-02-15 10:47:10 -08:00
parent a23ae0d8f3
commit 3708336f8f
2 changed files with 21 additions and 2 deletions

View File

@ -501,8 +501,8 @@ def _cache_write(cache_key: str,
"""
# Only write cache entries from the first process. Otherwise we create
# problems with contention for writes on some filesystems, e.g., GCS.
if backend.process_index() != 0:
logger.debug("Not writing persistent cache entry since process_index != 0")
if distributed.global_state.process_id != 0:
logger.debug("Not writing persistent cache entry since process_id != 0")
return
if host_callbacks:

View File

@ -32,6 +32,7 @@ from jax import pmap
from jax._src import compilation_cache as cc
from jax._src import compiler
from jax._src import config
from jax._src import distributed
from jax._src import monitoring
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -427,6 +428,24 @@ class CompilationCacheTest(jtu.JaxTestCase):
- previous_counts["/jax/compilation_cache/cache_hits"],
1)
@parameterized.parameters(0, 1)
def test_cache_write_with_process_restriction(self, process_id):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(0),
config.persistent_cache_min_entry_size_bytes(0),
mock.patch.object(distributed.global_state, "process_id", process_id),
):
cc.set_cache_dir(tmpdir)
jit(lambda x: x + 1)(1)
files_in_directory = len(os.listdir(tmpdir))
if process_id == 0:
self.assertEqual(files_in_directory, 1)
elif process_id == 1:
self.assertEqual(files_in_directory, 0)
@jtu.with_config(
jax_enable_compilation_cache=False,