[JAX] Add jax.clear_caches, plumb a way to clear pmap caches

fixes #10828

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 522654093
This commit is contained in:
Matthew Johnson 2023-04-07 12:09:26 -07:00 committed by jax authors
parent 0f368e4428
commit 26562a4382
5 changed files with 56 additions and 2 deletions

View File

@ -79,6 +79,7 @@ from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.api import clear_caches as clear_caches
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.util import curry as _deprecated_curry
from jax._src.custom_derivatives import custom_gradient as custom_gradient

View File

@ -30,6 +30,7 @@ import typing
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union,
overload, cast)
import weakref
import numpy as np
from contextlib import contextmanager, ExitStack
@ -60,11 +61,13 @@ from jax._src.api_util import (
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps)
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
from jax._src import util
from jax._src.interpreters import partial_eval as pe
@ -1822,6 +1825,7 @@ def _cpp_pmap(
cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg)
_pmap_cache_clears.add(cpp_mapped_f)
pmap_f = wraps(fun)(cpp_mapped_f)
@ -1831,6 +1835,8 @@ def _cpp_pmap(
return pmap_f
_pmap_cache_clears = weakref.WeakSet() # type: ignore
def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
devices, backend, axis_size, donate_tuple): # noqa: F811
@ -2867,3 +2873,21 @@ def live_arrays(platform=None):
If platform is None, it is the default backend.
"""
return xb.get_backend(platform).live_arrays()
def clear_caches():
# Clear all lu.cache and util.weakref_lru_cache instances (used for staging
# and Python-dispatch compiled executable caches).
lu.clear_all_caches()
util.clear_all_weakref_lru_caches()
# Clear all C++ compiled executable caches for pjit
pjit._cpp_pjit_cache.clear()
xc._xla.PjitFunctionCache.clear_all()
# Clear all C++ compiled executable caches for pmap
if xla_extension_version >= 146: # TODO(frostig): remove when ready
for fun in _pmap_cache_clears:
fun._cache_clear()
# Clear particular util.cache instances.
dispatch.xla_primitive_callable.cache_clear()

View File

@ -330,8 +330,17 @@ def cache(call: Callable):
memoized_fun.cache_clear = fun_caches.clear # type: ignore
memoized_fun.evict_function = _evict_function # type: ignore
cache_clearing_funs.add(fun_caches.clear)
return memoized_fun
cache_clearing_funs = weakref.WeakSet() # type: ignore
def clear_all_caches():
global cache_clearing_funs
for clear in cache_clearing_funs:
clear()
@partial(partial, tree_map)
def _copy_main_traces(x):
if isinstance(x, core.MainTrace):

View File

@ -20,6 +20,7 @@ import operator
from typing import (Any, Callable, Generic, Iterable, Iterator, List,
Optional, Sequence, Set, Tuple, TypeVar, overload,
TYPE_CHECKING, cast)
import weakref
import numpy as np
@ -266,7 +267,15 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
and strong refs to all subsequent operations. In all other respects it should
behave similar to `functools.lru_cache`.
"""
return xc.weakref_lru_cache(config._trace_context, call, maxsize)
global _weakref_lru_caches
cached_call = xc.weakref_lru_cache(config._trace_context, call, maxsize)
_weakref_lru_caches.add(cached_call)
return cached_call
_weakref_lru_caches = weakref.WeakSet() # type: ignore
def clear_all_weakref_lru_caches():
for cached_call in _weakref_lru_caches:
cached_call.cache_clear()
class Unhashable:
__slots__ = ["val"]

View File

@ -4247,6 +4247,17 @@ class APITest(jtu.JaxTestCase):
out = jax.grad(f)(3.0) # doesn't crash
self.assertAllClose(out, 1., check_dtypes=False)
@unittest.skipIf(xla_extension_version < 146,
'Test requires xla_extension_version >= 146')
def test_cache_clear_pmap(self):
@jax.pmap
def f(i):
return i * 2
f(np.arange(1, dtype='float32')).block_until_ready()
self.assertEqual(f._cache_size, 1)
jax.clear_caches()
self.assertEqual(f._cache_size, 0)
class RematTest(jtu.JaxTestCase):