mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
To avoid the inconsistency between process_index and process_id, replace backend.process_index with distributed.global_state.process_id in Jax compilation _cache_write function.
Testing: new unit test. PiperOrigin-RevId: 607385112
This commit is contained in:
parent
a23ae0d8f3
commit
3708336f8f
@ -501,8 +501,8 @@ def _cache_write(cache_key: str,
|
||||
"""
|
||||
# Only write cache entries from the first process. Otherwise we create
|
||||
# problems with contention for writes on some filesystems, e.g., GCS.
|
||||
if backend.process_index() != 0:
|
||||
logger.debug("Not writing persistent cache entry since process_index != 0")
|
||||
if distributed.global_state.process_id != 0:
|
||||
logger.debug("Not writing persistent cache entry since process_id != 0")
|
||||
return
|
||||
|
||||
if host_callbacks:
|
||||
|
@ -32,6 +32,7 @@ from jax import pmap
|
||||
from jax._src import compilation_cache as cc
|
||||
from jax._src import compiler
|
||||
from jax._src import config
|
||||
from jax._src import distributed
|
||||
from jax._src import monitoring
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
@ -427,6 +428,24 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
- previous_counts["/jax/compilation_cache/cache_hits"],
|
||||
1)
|
||||
|
||||
@parameterized.parameters(0, 1)
|
||||
def test_cache_write_with_process_restriction(self, process_id):
|
||||
with (
|
||||
tempfile.TemporaryDirectory() as tmpdir,
|
||||
config.persistent_cache_min_compile_time_secs(0),
|
||||
config.persistent_cache_min_entry_size_bytes(0),
|
||||
mock.patch.object(distributed.global_state, "process_id", process_id),
|
||||
):
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
jit(lambda x: x + 1)(1)
|
||||
|
||||
files_in_directory = len(os.listdir(tmpdir))
|
||||
if process_id == 0:
|
||||
self.assertEqual(files_in_directory, 1)
|
||||
elif process_id == 1:
|
||||
self.assertEqual(files_in_directory, 0)
|
||||
|
||||
|
||||
@jtu.with_config(
|
||||
jax_enable_compilation_cache=False,
|
||||
|
Loading…
x
Reference in New Issue
Block a user