mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Only perform compilation cache writes from process 0.
This avoids problems with contending writes on filesystems such as GCS. PiperOrigin-RevId: 572032482
This commit is contained in:
parent
cb51e37008
commit
4611d13c07
11
CHANGELOG.md
11
CHANGELOG.md
@ -8,6 +8,17 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
# jax 0.4.19
|
||||
|
||||
* Changes
|
||||
* JAX now requires SciPy 1.9 or newer.
|
||||
|
||||
* Bug fixes
|
||||
* Only process 0 in a multicontroller distributed JAX program will write
|
||||
persistent compilation cache entries. This fixes write contention if the
|
||||
cache is placed on a network filesystem such as GCS.
|
||||
* The version check for cusolver and cufft no longer considers the patch
|
||||
versions when determining if the installed version of these libraries is at
|
||||
least as new as the versions against which JAX was built.
|
||||
|
||||
# jaxlib 0.4.19
|
||||
|
||||
# jax 0.4.18 (Oct 6, 2023)
|
||||
|
@ -373,6 +373,12 @@ def _cache_write(cache_key: str,
|
||||
"""Writes the `serialized_computation` and its compilation time to the
|
||||
persistent compilation cache repository.
|
||||
"""
|
||||
# Only write cache entries from the first process. Otherwise we create
|
||||
# problems with contention for writes on some filesystems, e.g., GCS.
|
||||
if backend.process_index() != 0:
|
||||
logger.debug("Not writing persistent cache entry since process_index != 0")
|
||||
return
|
||||
|
||||
if host_callbacks:
|
||||
logger.info(
|
||||
"Not writing persistent cache entry for '%s' because it uses host "
|
||||
|
Loading…
x
Reference in New Issue
Block a user