Document the fact that jax.clear_caches() doesn't affect the persistent cache.

PiperOrigin-RevId: 626019057
This commit is contained in:
jax authors 2024-04-18 06:51:48 -07:00
parent fa66e731e6
commit bb8cf34a31

View File

@ -2996,7 +2996,11 @@ def live_arrays(platform=None):
return xb.get_backend(platform).live_arrays()
def clear_caches():
"""Clear all compilation and staging caches."""
"""Clear all compilation and staging caches.
This doesn't clear the persistent cache; to disable it (e.g. for benchmarks),
set the jax_enable_compilation_cache config option to False.
"""
# Clear all lu.cache and util.weakref_lru_cache instances (used for staging
# and Python-dispatch compiled executable caches).
lu.clear_all_caches()