Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Add util.cache to jax.clear_caches, and move pjit, sharding, array, etc. uses
of functools.lru_cache to util.cache, so that those caches will be cleared if
jax.clear_caches is called.

PiperOrigin-RevId: 642359226
This commit is contained in:
parent d20b9e324f
commit 6c34a56b87
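For orientation, here is the pattern the diff below applies throughout (a
minimal sketch; `_lookup` and its body are hypothetical, while the decorator
and its keyword arguments match the diff):

import functools
from jax._src import util  # provides the cache decorator changed below

# Before: a plain functools.lru_cache, which jax.clear_caches() cannot reach.
@functools.lru_cache(maxsize=4096)
def _lookup_old(key: int) -> int:
  return key * key  # stand-in for an expensive computation

# After: util.cache registers the wrapper's cache_clear in a module-level
# WeakSet (cache_clearing_funs), which util.clear_all_caches() iterates
# when jax.clear_caches() runs.
@util.cache(max_size=4096, trace_context_in_key=False)
def _lookup(key: int) -> int:
  return key * key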
@@ -2964,9 +2964,9 @@ def clear_caches():
   This doesn't clear the persistent cache; to disable it (e.g. for benchmarks),
   set the jax_enable_compilation_cache config option to False.
   """
-  # Clear all lu.cache and util.weakref_lru_cache instances (used for staging
-  # and Python-dispatch compiled executable caches).
+  # Clear all lu.cache, util.cache and util.weakref_lru_cache instances
+  # (used for staging and Python-dispatch compiled executable caches).
   lu.clear_all_caches()
+  util.clear_all_caches()
   util.clear_all_weakref_lru_caches()
 
   # Clear all C++ compiled executable caches for pjit
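As the docstring above notes, jax.clear_caches() leaves the persistent
compilation cache alone; disabling that one goes through the config flag the
docstring names (a usage sketch of the documented option, not part of this
diff):

import jax

# Disable the persistent compilation cache, e.g. for benchmarking:
jax.config.update("jax_enable_compilation_cache", False)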
@@ -45,7 +45,7 @@ from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding,
     device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape)  # pyformat: disable
 from jax._src.typing import ArrayLike, DLDeviceType
-from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
+from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
 import numpy as np
 
 
@@ -120,7 +120,7 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
   return jnp_value
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _cached_index_calc(s, shape):
   map_ = s.addressable_devices_indices_map(shape)
   seen_h_indices = set()
@@ -133,7 +133,7 @@ def _cached_index_calc(s, shape):
   return l
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _process_has_full_value_in_mcjax(s, shape):
   # Return False for single host as a fast path.
   if xla_bridge.process_count() == 1:
@@ -1081,7 +1081,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   src_indices = src_sharding.addressable_devices_indices_map(shape).values()
   dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
@@ -71,7 +71,7 @@ from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
 from jax._src.tree_util import tree_map
-from jax._src.util import curry
+from jax._src.util import curry, cache_clearing_funs
 
 
 traceback_util.register_exclusion(__file__)
@@ -359,17 +359,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
 
   memoized_fun.cache_clear = fun_caches.clear  # type: ignore
   memoized_fun.evict_function = _evict_function  # type: ignore
 
   cache_clearing_funs.add(memoized_fun.cache_clear)
 
   return memoized_fun
-
-cache_clearing_funs = weakref.WeakSet()  # type: ignore
-
-def clear_all_caches():
-  global cache_clearing_funs
-  for clear in cache_clearing_funs:
-    clear()
-
 @partial(partial, tree_map)
 def _copy_main_traces(x):
@@ -91,7 +91,7 @@ class ResourceEnv(NamedTuple):
     return f"ResourceEnv(mesh=Mesh({mesh_repr}), {self.loops!r})"
 
 
-@functools.lru_cache(maxsize=128)
+@util.cache(max_size=128, trace_context_in_key=False)
 def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   if global_mesh.empty:
     return global_mesh
@@ -17,7 +17,7 @@ from __future__ import annotations
 from collections import defaultdict
 from collections.abc import Sequence, Iterable
 import dataclasses
-from functools import partial, lru_cache
+from functools import partial
 import inspect
 import itertools as it
 import logging
@@ -1013,7 +1013,7 @@ class PytreeLeaf:
   def __repr__(self): return "pytree leaf"
 
 
-@lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
                                in_layouts_treedef, in_layouts_leaves,
                                in_avals, in_tree, debug_info,
@@ -1211,7 +1211,7 @@ def _create_pjit_jaxpr(fun, in_type, attr_data, debug_info, out_paths, ignored_i
   return closed_jaxpr, final_consts, global_out_avals, attrs_tracked
 
 
-@lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def _check_and_canonicalize_out_shardings(
     out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
     out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set):
@@ -17,7 +17,7 @@ from __future__ import annotations
 from collections.abc import Mapping, Sequence
 import functools
 
-from jax._src.util import safe_zip, use_cpp_class
+from jax._src.util import safe_zip, use_cpp_class, cache
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.op_shardings import (
@@ -30,7 +30,7 @@ Index = tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _addressable_devices_indices_map(
     sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]:
   global_map = sharding.devices_indices_map(global_shape)
@@ -42,7 +42,7 @@ def _addressable_devices_indices_map(
   return {d: ind for d, ind in global_map.items()
           if d.process_index == d.client.process_index()}
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]:
   s.shard_shape(global_shape)  # raises a good error message
   hlo_sharding = s._to_xla_hlo_sharding(len(global_shape))
@@ -51,7 +51,7 @@ def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]
   return dict(safe_zip(s._device_assignment, indices))
 
 
-@functools.lru_cache(maxsize=4096)
+@cache(max_size=4096, trace_context_in_key=False)
 def _common_shard_shape(self, global_shape: Shape) -> Shape:
   hlo_sharding = self._to_xla_hlo_sharding(len(global_shape))
   if is_op_sharding_replicated(hlo_sharding):
@@ -52,7 +52,7 @@ class TransferToMemoryKind:
   memory_kind: str
 
 
-@functools.lru_cache
+@util.cache(max_size=128, trace_context_in_key=False)
 def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
   try:
     for p in parsed_pspec:
@@ -75,7 +75,7 @@ def hashed_index(x) -> int:
   return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
   try:
     device_indices_map_fn = sharding.devices_indices_map
@@ -95,7 +95,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
   return out
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def named_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   mesh_shape = self.mesh.shape
@@ -297,7 +297,7 @@ class NamedSharding(sharding.Sharding):
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
 
 
-@functools.lru_cache
+@util.cache(max_size=128, trace_context_in_key=False)
 def get_replicated_hlo_sharding():
   return xc.HloSharding.replicate()
 
@@ -373,7 +373,7 @@ class SingleDeviceSharding(sharding.Sharding):
     return True
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def pmap_sharding_devices_indices_map(
     self, global_shape: Shape) -> Mapping[Device, Index]:
   self.shard_shape(global_shape)  # raises a good error message
@@ -561,7 +561,7 @@ def _op_sharding_to_pos_sharding(
   return p
 
 
-@functools.lru_cache(maxsize=4096)
+@util.cache(max_size=4096, trace_context_in_key=False)
 def _positional_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   if self.shape == (1,) * self.ndim:
@@ -16,6 +16,7 @@ from __future__ import annotations
 
 import abc
 from collections.abc import Iterable, Iterator, Sequence
+import dataclasses
 import functools
 from functools import partial
 import itertools as it
@@ -285,7 +286,18 @@ def split_merge(predicate, xs):
 
   return lhs, rhs, merge
 
-def cache(max_size=4096):
+
+@dataclasses.dataclass(frozen=True)
+class _IgnoreKey:
+
+  def __hash__(self):
+    return hash(self.__class__)
+
+  def __eq__(self, other):
+    return isinstance(other, _IgnoreKey)
+
+
+def cache(max_size=4096, trace_context_in_key=True):
   def wrap(f):
     @functools.lru_cache(max_size)
     def cached(_, *args, **kwargs):
@@ -295,14 +307,24 @@ def cache(max_size=4096):
     def wrapper(*args, **kwargs):
       if config.check_tracer_leaks.value:
         return f(*args, **kwargs)
-      else:
+      elif trace_context_in_key:
         return cached(config.trace_context(), *args, **kwargs)
+      else:
+        return cached(_IgnoreKey(), *args, **kwargs)
 
     wrapper.cache_clear = cached.cache_clear
     wrapper.cache_info = cached.cache_info
+    cache_clearing_funs.add(wrapper.cache_clear)
     return wrapper
   return wrap
 
+cache_clearing_funs = weakref.WeakSet()  # type: ignore
+
+def clear_all_caches():
+  global cache_clearing_funs
+  for clear in cache_clearing_funs:
+    clear()
+
 memoize = cache(max_size=None)
 
 def weakref_lru_cache(call: Callable, maxsize=2048):
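To make the new keyword concrete, here is a minimal sketch of how
trace_context_in_key=False behaves (shard_shape is a hypothetical helper; the
decorator, cache_info, and the clearing path come from the code above):

import jax
from jax._src import util

@util.cache(max_size=128, trace_context_in_key=False)
def shard_shape(n):
  # Keyed only on `n`: the constant _IgnoreKey() sentinel stands in for
  # config.trace_context(), and all _IgnoreKey instances compare equal.
  return (n, n)

shard_shape(8)
assert shard_shape.cache_info().currsize == 1
jax.clear_caches()  # reaches this cache via cache_clearing_funs
assert shard_shape.cache_info().currsize == 0

With the default trace_context_in_key=True, calls made under different trace
contexts occupy separate cache entries; either way the cache is bypassed
entirely when check_tracer_leaks is enabled.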
@@ -31,6 +31,7 @@ from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.util import safe_zip
+from jax._src.sharding import common_devices_indices_map
 from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
                                      pmap_sharding_devices_indices_map,
                                      NamedSharding, GSPMDSharding,
@@ -831,6 +832,15 @@ class ShardingTest(jtu.JaxTestCase):
     self.assertListEqual(hlo_sharding.tile_assignment_devices(),
                          [0, 2, 4, 6, 1, 3, 5, 7])
 
+  def test_util_clear_cache(self):
+    mesh = jtu.create_global_mesh((1,), ('x',))
+    s = NamedSharding(mesh, P())
+    s.devices_indices_map((8,))
+    jax.clear_caches()
+    s.devices_indices_map((8,))
+    c = common_devices_indices_map.cache_info()
+    self.assertEqual(c.currsize, 1)
+
   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y")),
       ("mesh_x", P("x")),
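Note that cache_info() on a util.cache-decorated function forwards to the
underlying functools.lru_cache, so the standard CacheInfo fields are all
available to assertions like the one above (a usage sketch, reusing the
test's import of common_devices_indices_map):

info = common_devices_indices_map.cache_info()
print(info.hits, info.misses, info.maxsize, info.currsize)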