mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[JAX] Add jax.clear_caches, plumb a way to clear pmap caches
fixes #10828 Co-authored-by: Roy Frostig <frostig@google.com> PiperOrigin-RevId: 522654093
This commit is contained in:
parent
0f368e4428
commit
26562a4382
@ -79,6 +79,7 @@ from jax._src.api import block_until_ready as block_until_ready
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
|
||||
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
||||
from jax._src.api import clear_backends as clear_backends
|
||||
from jax._src.api import clear_caches as clear_caches
|
||||
from jax._src.custom_derivatives import closure_convert as closure_convert
|
||||
from jax._src.util import curry as _deprecated_curry
|
||||
from jax._src.custom_derivatives import custom_gradient as custom_gradient
|
||||
|
@ -30,6 +30,7 @@ import typing
|
||||
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
|
||||
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union,
|
||||
overload, cast)
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
from contextlib import contextmanager, ExitStack
|
||||
@ -60,11 +61,13 @@ from jax._src.api_util import (
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import PmapSharding
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps)
|
||||
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
|
||||
from jax._src import util
|
||||
|
||||
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -1822,6 +1825,7 @@ def _cpp_pmap(
|
||||
|
||||
cpp_mapped_f = pmap_lib.pmap(
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg)
|
||||
_pmap_cache_clears.add(cpp_mapped_f)
|
||||
|
||||
pmap_f = wraps(fun)(cpp_mapped_f)
|
||||
|
||||
@ -1831,6 +1835,8 @@ def _cpp_pmap(
|
||||
|
||||
return pmap_f
|
||||
|
||||
_pmap_cache_clears = weakref.WeakSet() # type: ignore
|
||||
|
||||
|
||||
def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
|
||||
devices, backend, axis_size, donate_tuple): # noqa: F811
|
||||
@ -2867,3 +2873,21 @@ def live_arrays(platform=None):
|
||||
If platform is None, it is the default backend.
|
||||
"""
|
||||
return xb.get_backend(platform).live_arrays()
|
||||
|
||||
def clear_caches():
|
||||
# Clear all lu.cache and util.weakref_lru_cache instances (used for staging
|
||||
# and Python-dispatch compiled executable caches).
|
||||
lu.clear_all_caches()
|
||||
util.clear_all_weakref_lru_caches()
|
||||
|
||||
# Clear all C++ compiled executable caches for pjit
|
||||
pjit._cpp_pjit_cache.clear()
|
||||
xc._xla.PjitFunctionCache.clear_all()
|
||||
|
||||
# Clear all C++ compiled executable caches for pmap
|
||||
if xla_extension_version >= 146: # TODO(frostig): remove when ready
|
||||
for fun in _pmap_cache_clears:
|
||||
fun._cache_clear()
|
||||
|
||||
# Clear particular util.cache instances.
|
||||
dispatch.xla_primitive_callable.cache_clear()
|
||||
|
@ -330,8 +330,17 @@ def cache(call: Callable):
|
||||
memoized_fun.cache_clear = fun_caches.clear # type: ignore
|
||||
memoized_fun.evict_function = _evict_function # type: ignore
|
||||
|
||||
cache_clearing_funs.add(fun_caches.clear)
|
||||
|
||||
return memoized_fun
|
||||
|
||||
cache_clearing_funs = weakref.WeakSet() # type: ignore
|
||||
|
||||
def clear_all_caches():
|
||||
global cache_clearing_funs
|
||||
for clear in cache_clearing_funs:
|
||||
clear()
|
||||
|
||||
@partial(partial, tree_map)
|
||||
def _copy_main_traces(x):
|
||||
if isinstance(x, core.MainTrace):
|
||||
|
@ -20,6 +20,7 @@ import operator
|
||||
from typing import (Any, Callable, Generic, Iterable, Iterator, List,
|
||||
Optional, Sequence, Set, Tuple, TypeVar, overload,
|
||||
TYPE_CHECKING, cast)
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -266,7 +267,15 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
|
||||
and strong refs to all subsequent operations. In all other respects it should
|
||||
behave similar to `functools.lru_cache`.
|
||||
"""
|
||||
return xc.weakref_lru_cache(config._trace_context, call, maxsize)
|
||||
global _weakref_lru_caches
|
||||
cached_call = xc.weakref_lru_cache(config._trace_context, call, maxsize)
|
||||
_weakref_lru_caches.add(cached_call)
|
||||
return cached_call
|
||||
_weakref_lru_caches = weakref.WeakSet() # type: ignore
|
||||
|
||||
def clear_all_weakref_lru_caches():
|
||||
for cached_call in _weakref_lru_caches:
|
||||
cached_call.cache_clear()
|
||||
|
||||
class Unhashable:
|
||||
__slots__ = ["val"]
|
||||
|
@ -4247,6 +4247,17 @@ class APITest(jtu.JaxTestCase):
|
||||
out = jax.grad(f)(3.0) # doesn't crash
|
||||
self.assertAllClose(out, 1., check_dtypes=False)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 146,
|
||||
'Test requires xla_extension_version >= 146')
|
||||
def test_cache_clear_pmap(self):
|
||||
@jax.pmap
|
||||
def f(i):
|
||||
return i * 2
|
||||
|
||||
f(np.arange(1, dtype='float32')).block_until_ready()
|
||||
self.assertEqual(f._cache_size, 1)
|
||||
jax.clear_caches()
|
||||
self.assertEqual(f._cache_size, 0)
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user