diff --git a/CHANGELOG.md b/CHANGELOG.md index 491e4174d..29110aca7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,17 @@ Remember to align the itemized text with the first line of an item within a list # jax 0.4.19 +* Changes + * JAX now requires SciPy 1.9 or newer. + +* Bug fixes + * Only process 0 in a multicontroller distributed JAX program will write + persistent compilation cache entries. This fixes write contention if the + cache is placed on a network filesystem such as GCS. + * The version check for cusolver and cufft no longer considers the patch + versions when determining if the installed version of these libraries is at + least as new as the versions against which JAX was built. + # jaxlib 0.4.19 # jax 0.4.18 (Oct 6, 2023) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 425df2d6b..04040c97e 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -373,6 +373,12 @@ def _cache_write(cache_key: str, """Writes the `serialized_computation` and its compilation time to the persistent compilation cache repository. """ + # Only write cache entries from the first process. Otherwise we create + # problems with contention for writes on some filesystems, e.g., GCS. + if backend.process_index() != 0: + logger.debug("Not writing persistent cache entry since process_index != 0") + return + if host_callbacks: logger.info( "Not writing persistent cache entry for '%s' because it uses host "