Only perform compilation cache writes from process 0.

This avoids problems with contending writes on filesystems such as GCS.

PiperOrigin-RevId: 572032482
This commit is contained in:
Peter Hawkins 2023-10-09 13:54:32 -07:00 committed by jax authors
parent cb51e37008
commit 4611d13c07
2 changed files with 17 additions and 0 deletions

View File

@ -8,6 +8,17 @@ Remember to align the itemized text with the first line of an item within a list
# jax 0.4.19
* Changes
  * JAX now requires SciPy 1.9 or newer.
* Bug fixes
  * Only process 0 in a multicontroller distributed JAX program will write
    persistent compilation cache entries. This fixes write contention if the
    cache is placed on a network filesystem such as GCS.
  * The version check for cusolver and cufft no longer considers the patch
    versions when determining if the installed version of these libraries is at
    least as new as the versions against which JAX was built.
# jaxlib 0.4.19
# jax 0.4.18 (Oct 6, 2023)

View File

@ -373,6 +373,12 @@ def _cache_write(cache_key: str,
"""Writes the `serialized_computation` and its compilation time to the
persistent compilation cache repository.
"""
# Only write cache entries from the first process. Otherwise we create
# problems with contention for writes on some filesystems, e.g., GCS.
if backend.process_index() != 0:
logger.debug("Not writing persistent cache entry since process_index != 0")
return
if host_callbacks:
logger.info(
"Not writing persistent cache entry for '%s' because it uses host "