Add util.cache to jax.clear_caches, and move the pjit, sharding, array, etc. uses of functools.lru_cache over to util.cache, so that those caches are cleared whenever jax.clear_caches is called.

PiperOrigin-RevId: 642359226
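In effect, each migrated call site changes as in this hedged before/after sketch (the helper _device_indices and its body are hypothetical; only the decorator signatures are taken from the diff below):

    import functools
    from jax._src import util

    # Before: functools.lru_cache keys only on the arguments and is not
    # touched by jax.clear_caches().
    @functools.lru_cache(maxsize=4096)
    def _device_indices(sharding, shape):
      return sharding.devices_indices_map(shape)

    # After: util.cache registers its cache_clear in util.cache_clearing_funs,
    # so jax.clear_caches() empties it; trace_context_in_key=False keeps
    # config.trace_context() out of the key, mirroring lru_cache's keying.
    @util.cache(max_size=4096, trace_context_in_key=False)
    def _device_indices(sharding, shape):
      return sharding.devices_indices_map(shape)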
Yash Katariya 2024-06-11 12:46:11 -07:00 committed by jax authors
parent d20b9e324f
commit 6c34a56b87
9 changed files with 56 additions and 32 deletions

View File

@@ -2964,9 +2964,9 @@ def clear_caches():
   This doesn't clear the persistent cache; to disable it (e.g. for benchmarks),
   set the jax_enable_compilation_cache config option to False.
   """
-  # Clear all lu.cache and util.weakref_lru_cache instances (used for staging
-  # and Python-dispatch compiled executable caches).
-  lu.clear_all_caches()
+  # Clear all lu.cache, util.cache and util.weakref_lru_cache instances
+  # (used for staging and Python-dispatch compiled executable caches).
+  util.clear_all_caches()
   util.clear_all_weakref_lru_caches()
   # Clear all C++ compiled executable caches for pjit
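As the docstring says, the persistent compilation cache is separate from these in-memory caches; a minimal sketch of the benchmarking setup it suggests, using the public config API:

    import jax

    jax.config.update("jax_enable_compilation_cache", False)  # persistent cache off
    jax.clear_caches()  # in-memory caches: lu.cache, util.cache, util.weakref_lru_cache, pjit executables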

View File

@@ -45,7 +45,7 @@ from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding,
     device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape)  # pyformat: disable
 from jax._src.typing import ArrayLike, DLDeviceType
-from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
+from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
 import numpy as np
@@ -120,7 +120,7 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
   return jnp_value
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _cached_index_calc(s, shape):
   map_ = s.addressable_devices_indices_map(shape)
   seen_h_indices = set()
@@ -133,7 +133,7 @@ def _cached_index_calc(s, shape):
   return l
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _process_has_full_value_in_mcjax(s, shape):
   # Return False for single host as a fast path.
   if xla_bridge.process_count() == 1:
@@ -1081,7 +1081,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   src_indices = src_sharding.addressable_devices_indices_map(shape).values()
   dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()

View File

@@ -71,7 +71,7 @@ from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
 from jax._src.tree_util import tree_map
-from jax._src.util import curry
+from jax._src.util import curry, cache_clearing_funs
 
 traceback_util.register_exclusion(__file__)
@@ -359,17 +359,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
   memoized_fun.cache_clear = fun_caches.clear  # type: ignore
   memoized_fun.evict_function = _evict_function  # type: ignore
   cache_clearing_funs.add(memoized_fun.cache_clear)
   return memoized_fun
 
-
-cache_clearing_funs = weakref.WeakSet()  # type: ignore
-
-def clear_all_caches():
-  global cache_clearing_funs
-  for clear in cache_clearing_funs:
-    clear()
-
 
 @partial(partial, tree_map)
 def _copy_main_traces(x):

View File

@@ -91,7 +91,7 @@ class ResourceEnv(NamedTuple):
     return f"ResourceEnv(mesh=Mesh({mesh_repr}), {self.loops!r})"
 
 
-@functools.lru_cache(maxsize=128)
+@util.cache(max_size=128, trace_context_in_key=False)
 def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   if global_mesh.empty:
     return global_mesh

View File

@@ -17,7 +17,7 @@ from __future__ import annotations
 from collections import defaultdict
 from collections.abc import Sequence, Iterable
 import dataclasses
-from functools import partial, lru_cache
+from functools import partial
 import inspect
 import itertools as it
 import logging
@@ -1013,7 +1013,7 @@ class PytreeLeaf:
   def __repr__(self): return "pytree leaf"
 
 
-@lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
                                in_layouts_treedef, in_layouts_leaves,
                                in_avals, in_tree, debug_info,
@@ -1211,7 +1211,7 @@ def _create_pjit_jaxpr(fun, in_type, attr_data, debug_info, out_paths, ignored_i
   return closed_jaxpr, final_consts, global_out_avals, attrs_tracked
 
 
-@lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def _check_and_canonicalize_out_shardings(
     out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
     out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set):

View File

@@ -17,7 +17,7 @@ from __future__ import annotations
 from collections.abc import Mapping, Sequence
 import functools
 
-from jax._src.util import safe_zip, use_cpp_class
+from jax._src.util import safe_zip, use_cpp_class, cache
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.op_shardings import (
@@ -30,7 +30,7 @@ Index = tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _addressable_devices_indices_map(
     sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]:
   global_map = sharding.devices_indices_map(global_shape)
@@ -42,7 +42,7 @@ def _addressable_devices_indices_map(
   return {d: ind for d, ind in global_map.items()
           if d.process_index == d.client.process_index()}
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]:
   s.shard_shape(global_shape)  # raises a good error message
   hlo_sharding = s._to_xla_hlo_sharding(len(global_shape))
@@ -51,7 +51,7 @@ def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]
   return dict(safe_zip(s._device_assignment, indices))
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _common_shard_shape(self, global_shape: Shape) -> Shape:
   hlo_sharding = self._to_xla_hlo_sharding(len(global_shape))
   if is_op_sharding_replicated(hlo_sharding):

View File

@@ -52,7 +52,7 @@ class TransferToMemoryKind:
   memory_kind: str
 
 
-@functools.lru_cache
+@util.cache(max_size=128, trace_context_in_key=False)
 def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
   try:
     for p in parsed_pspec:
@@ -75,7 +75,7 @@ def hashed_index(x) -> int:
   return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
   try:
     device_indices_map_fn = sharding.devices_indices_map
@@ -95,7 +95,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
   return out
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def named_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   mesh_shape = self.mesh.shape
@@ -297,7 +297,7 @@ class NamedSharding(sharding.Sharding):
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
 
 
-@functools.lru_cache
+@util.cache(max_size=128, trace_context_in_key=False)
 def get_replicated_hlo_sharding():
   return xc.HloSharding.replicate()
@@ -373,7 +373,7 @@ class SingleDeviceSharding(sharding.Sharding):
     return True
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def pmap_sharding_devices_indices_map(
     self, global_shape: Shape) -> Mapping[Device, Index]:
   self.shard_shape(global_shape)  # raises a good error message
@@ -561,7 +561,7 @@ def _op_sharding_to_pos_sharding(
   return p
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def _positional_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   if self.shape == (1,) * self.ndim:

View File

@@ -16,6 +16,7 @@ from __future__ import annotations
 import abc
 from collections.abc import Iterable, Iterator, Sequence
 import dataclasses
+import functools
 from functools import partial
 import itertools as it
@@ -285,7 +286,18 @@ def split_merge(predicate, xs):
   return lhs, rhs, merge
 
 
-def cache(max_size=4096):
+@dataclasses.dataclass(frozen=True)
+class _IgnoreKey:
+
+  def __hash__(self):
+    return hash(self.__class__)
+
+  def __eq__(self, other):
+    return isinstance(other, _IgnoreKey)
+
+
+def cache(max_size=4096, trace_context_in_key=True):
   def wrap(f):
     @functools.lru_cache(max_size)
     def cached(_, *args, **kwargs):
@@ -295,14 +307,24 @@ def cache(max_size=4096):
     def wrapper(*args, **kwargs):
       if config.check_tracer_leaks.value:
         return f(*args, **kwargs)
-      else:
+      elif trace_context_in_key:
         return cached(config.trace_context(), *args, **kwargs)
+      else:
+        return cached(_IgnoreKey(), *args, **kwargs)
 
     wrapper.cache_clear = cached.cache_clear
     wrapper.cache_info = cached.cache_info
+    cache_clearing_funs.add(wrapper.cache_clear)
     return wrapper
   return wrap
 
+
+cache_clearing_funs = weakref.WeakSet()  # type: ignore
+
+def clear_all_caches():
+  global cache_clearing_funs
+  for clear in cache_clearing_funs:
+    clear()
+
+
 memoize = cache(max_size=None)
 
 
 def weakref_lru_cache(call: Callable, maxsize=2048):
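A hedged usage sketch of the reworked decorator (the double function is illustrative only, not part of the change): because wrapper.cache_clear is now added to cache_clearing_funs, jax.clear_caches() reaches these caches as well.

    import jax
    from jax._src import util

    @util.cache(max_size=128, trace_context_in_key=False)
    def double(x):
      return 2 * x

    double(21)
    assert double.cache_info().currsize == 1  # one memoized entry
    jax.clear_caches()  # now also invokes util.clear_all_caches()
    assert double.cache_info().currsize == 0  # emptied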

View File

@@ -31,6 +31,7 @@ from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.util import safe_zip
+from jax._src.sharding import common_devices_indices_map
 from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
                                      pmap_sharding_devices_indices_map,
                                      NamedSharding, GSPMDSharding,
@@ -831,6 +832,15 @@ class ShardingTest(jtu.JaxTestCase):
     self.assertListEqual(hlo_sharding.tile_assignment_devices(),
                          [0, 2, 4, 6, 1, 3, 5, 7])
 
+  def test_util_clear_cache(self):
+    mesh = jtu.create_global_mesh((1,), ('x',))
+    s = NamedSharding(mesh, P())
+    s.devices_indices_map((8,))
+    jax.clear_caches()
+    s.devices_indices_map((8,))
+    c = common_devices_indices_map.cache_info()
+    self.assertEqual(c.currsize, 1)
+
   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y")),
       ("mesh_x", P("x")),