Generalize the global jit cpp cache keys so we can add more keys beyond the current donate_argnums.

This allows us to get more cache hits globally. For example:

Before:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```

After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```
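
(Not part of the change, just a sketch of how to observe this with a new enough jaxlib, per the version guards below: the counter is the same internal helper the new test at the bottom uses, and the single-device mesh is only for illustration.)

```
import jax
import numpy as np
import jax._src.test_util as jtu
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
s = NamedSharding(mesh, P('x'))

def f(x):
  return x * 2

with jtu.count_pjit_cpp_cache_miss() as count:
  jax.jit(f, out_shardings=s)(np.arange(8))
  jax.jit(f, out_shardings=s)(np.arange(8))

print(count[0])  # 1 after this change (only the first call misses); 2 before
```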

Also, we can remove the hack (which I didn't like) in multihost_utils.py.

PiperOrigin-RevId: 665574475
Yash Katariya, 2024-08-20 16:18:21 -07:00 (committed by jax authors)
commit 82c9da020a (parent 7cd10d8854)
5 changed files with 148 additions and 56 deletions

View File

@@ -2965,7 +2965,8 @@ def clear_backends():
pjit._infer_params_cached.cache_clear()
pjit._pjit_lower_cached.cache_clear()
pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error
pjit._cpp_pjit_cache.clear()
pjit._cpp_pjit_cache_fun_only.clear()
pjit._cpp_pjit_cache_explicit_attributes.clear()
xc._xla.PjitFunctionCache.clear_all()
@atexit.register
@@ -2993,7 +2994,8 @@ def clear_caches():
util.clear_all_weakref_lru_caches()
# Clear all C++ compiled executable caches for pjit
pjit._cpp_pjit_cache.clear()
pjit._cpp_pjit_cache_fun_only.clear()
pjit._cpp_pjit_cache_explicit_attributes.clear()
pjit._infer_params_cached.cache_clear()
xc._xla.PjitFunctionCache.clear_all()

View File

@@ -22,6 +22,7 @@ from collections import namedtuple
from collections.abc import Callable, Sequence, Iterable, Iterator
import dataclasses
from functools import partial, lru_cache, cached_property
import functools
import itertools as it
import logging
import math
@@ -89,6 +90,7 @@ unsafe_map, map = map, safe_map # type: ignore
logger = logging.getLogger(__name__)
Index = Union[int, slice, tuple[Union[int, slice], ...]]
PyTreeDef = tree_util.PyTreeDef
NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
@@ -2922,6 +2924,33 @@ class MeshExecutableFastpathData(NamedTuple):
in_device_local_layouts: Sequence[DeviceLocalLayout | None]
@dataclasses.dataclass(frozen=True)
class JitGlobalCppCacheKeys:
donate_argnums: tuple[int, ...] | None = None
donate_argnames: tuple[str, ...] | None = None
device: xc.Device | None = None
backend: str | None = None
in_shardings_treedef: PyTreeDef | None = None
in_shardings_leaves: tuple[Any, ...] | None = None
out_shardings_treedef: PyTreeDef | None = None
out_shardings_leaves: tuple[Any, ...] | None = None
in_layouts_treedef: PyTreeDef | None = None
in_layouts_leaves: tuple[Any, ...] | None = None
out_layouts_treedef: PyTreeDef | None = None
out_layouts_leaves: tuple[Any, ...] | None = None
@functools.cached_property
def contains_explicit_attributes(self):
return (self.donate_argnums is not None or
self.donate_argnames is not None or
self.device is not None or
self.backend is not None or
any(not is_unspecified(i) for i in self.in_shardings_leaves) or
any(not is_unspecified(o) for o in self.out_shardings_leaves) or
any(i is not None for i in self.in_layouts_leaves) or
any(o is not None for o in self.out_layouts_leaves))
def reflatten_outputs_for_dispatch(out_tree, out_flat):
# We arrive at dispatch having flattened according to the default
# pytree registry, but we want to re-flatten according to our
@@ -3037,9 +3066,14 @@ class MeshExecutable(stages.XlaExecutable):
fastpath_data = None
return outs, fastpath_data, False # Do not remove cache entry
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)
if xla_extension_version >= 283:
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [],
JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
else:
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)
if xla_extension_version < 282:
def cc_shard_arg(x, sharding):

View File

@@ -63,6 +63,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import sharding
from jax._src.mesh import AbstractMesh
from jax._src.sharding_impls import (
@@ -165,7 +166,6 @@ class PjitInfo(NamedTuple):
keep_unused: bool
inline: bool
abstracted_axes: Any | None
has_explicit_sharding: bool
use_resource_env: bool # False for jit, True for pjit
# Hash and compare PjitInfo by identity when used as a cache key.
@@ -314,14 +314,39 @@ def _cpp_pjit_evict_fn(self):
# The entries are doubled here from the default 4096 because _pjit_call_impl
# also has a cpp dispatch path and that would double the number of entries in
# the global shared cache.
_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192)
# This cache is only used for jits with just the fun. For example: jax.jit(f)
_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
# This cache is used for jits where extra arguments are specified besides the
# fun. For example: jax.jit(f, donate_argnums=...) or
# jax.jit(f, out_shardings=...), etc. We don't use the same cache because its
# capacity could fill up very fast with all the jitted functions inside JAX,
# which might evict train_step, for example.
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
def _get_cpp_global_cache(pjit_has_explicit_sharding):
if pjit_has_explicit_sharding:
return xc._xla.PjitFunctionCache()
else:
return _cpp_pjit_cache
if xla_extension_version < 283:
def _get_cpp_global_cache(pjit_has_explicit_sharding):
if pjit_has_explicit_sharding:
return xc._xla.PjitFunctionCache()
else:
return _cpp_pjit_cache_fun_only
def _pjit_explicit_sharding_and_layout(
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
device, backend) -> bool:
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(o) for o in out_shardings_flat) or
any(i is not None for i in in_layouts_flat) or
any(o is not None for o in out_layouts_flat))
else:
def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore
if contains_explicit_attributes:
return _cpp_pjit_cache_explicit_attributes
else:
return _cpp_pjit_cache_fun_only
def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
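
As a side note (not in the diff): a rough sketch of how a key routes to one of the two caches above on new enough jaxlib. The import locations and the UNSPECIFIED sentinel are my assumptions based on the references in this change:

```
import dataclasses
from jax._src.interpreters import pxla          # assumed home of JitGlobalCppCacheKeys
from jax._src.sharding_impls import UNSPECIFIED  # assumed sentinel matched by is_unspecified

# A key for a plain jax.jit(f): no donation, no device/backend, unspecified
# shardings and layouts -> routed to _cpp_pjit_cache_fun_only.
fun_only = pxla.JitGlobalCppCacheKeys(
    in_shardings_leaves=(UNSPECIFIED,), out_shardings_leaves=(UNSPECIFIED,),
    in_layouts_leaves=(None,), out_layouts_leaves=(None,))
assert not fun_only.contains_explicit_attributes

# Any explicit attribute (donation, device, backend, sharding, layout) flips
# the property -> routed to _cpp_pjit_cache_explicit_attributes.
donated = dataclasses.replace(fun_only, donate_argnums=(0,))
assert donated.contains_explicit_attributes
```
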
@@ -339,11 +364,34 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun),
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(jit_info.has_explicit_sharding))
if xla_extension_version >= 283:
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=jit_info.donate_argnums,
donate_argnames=jit_info.donate_argnames,
device=jit_info.device, backend=jit_info.backend,
in_shardings_treedef=jit_info.in_shardings_treedef,
in_shardings_leaves=jit_info.in_shardings_leaves,
out_shardings_treedef=jit_info.out_shardings_treedef,
out_shardings_leaves=jit_info.out_shardings_leaves,
in_layouts_treedef=jit_info.in_layouts_treedef,
in_layouts_leaves=jit_info.in_layouts_leaves,
out_layouts_treedef=jit_info.out_layouts_treedef,
out_layouts_leaves=jit_info.out_layouts_leaves)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))
else:
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
jit_info.device, jit_info.backend)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, jit_info.donate_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
@@ -351,17 +399,6 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
return cpp_pjitted_f
def _pjit_explicit_sharding_and_layout(
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
device, backend) -> bool:
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(o) for o in out_shardings_flat) or
any(i is not None for i in in_layouts_flat) or
any(o is not None for o in out_layouts_flat))
def _split_layout_and_sharding(entries):
entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
layouts, shardings = [], []
@@ -445,10 +482,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
static_argnames)
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings_leaves, out_shardings_leaves, in_layouts_leaves,
out_layouts_leaves, device, backend)
return PjitInfo(
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
@@ -466,7 +499,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
has_explicit_sharding=has_explicit_sharding,
use_resource_env=use_resource_env)
@@ -1724,13 +1756,26 @@ def _pjit_call_impl(*args, jaxpr,
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
if xla_extension_version >= 283:
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=donated_argnums, donate_argnames=None,
device=None, backend=None,
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
out_layouts_treedef=None, out_layouts_leaves=out_layouts)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], cache_key,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
else:
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)
pjit_p.def_impl(_pjit_call_impl)

View File

@@ -90,19 +90,17 @@ def sync_global_devices(name: str):
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
return x
@lru_cache(maxsize=128)
def _jitted_identity_fn(sharding):
return jax.jit(_identity_fn, out_shardings=sharding)
def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment)
out = _jitted_identity_fn(reps)(inp)
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
else:
# All inputs here will be fully addressable.
if jax.process_count() == 1:
@@ -125,7 +123,8 @@ def _handle_array_process_allgather(inp, tiled):
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
global_aval.shape, s, bufs)
out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr)
out = jax.jit(_identity_fn,
out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr)
return np.asarray(out.addressable_data(0))
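
Why dropping the lru_cache is safe: it only existed to hand back one jitted callable per sharding, and the generalized C++ cache key now does that matching itself, relying on equivalent shardings comparing equal. A quick sanity check of that equality (single-device mesh assumed, not from the diff):

```
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
# Two independently constructed but equivalent shardings hash/compare equal,
# so a second jax.jit(_identity_fn, out_shardings=...) call can reuse the
# first call's C++ cache entry without any Python-side memoization.
s1 = NamedSharding(mesh, P())
s2 = NamedSharding(mesh, P())
assert s1 == s2 and hash(s1) == hash(s2)
```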

View File

@@ -635,18 +635,16 @@ class PJitTest(jtu.BufferDonationTestCase):
@jtu.with_mesh([('x', 2), ('y', 1)])
def testAutodiffCache(self):
f = pjit(
lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None
)
f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None)
x = jnp.arange(16, dtype=jnp.float32)
jax.grad(f)(x) # Warm up the cache.
before = pjit_lib._pjit_lower_cached.cache_info()
jax.grad(f)(x)
after = pjit_lib._pjit_lower_cached.cache_info()
# One hit for the forward pass, one hit for backward.
self.assertEqual(after.hits, before.hits + 2)
self.assertEqual(after.misses, before.misses)
jax.grad(f)(x) # Warm up the cache.
with jtu.count_pjit_cpp_cache_miss() as count:
jax.grad(f)(x)
if xla_extension_version >= 283:
self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
else:
self.assertEqual(count[0], 2)
@jtu.with_mesh([('x', 2), ('y', 1)])
def testEvalJaxpr(self):
@@ -4467,6 +4465,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
' match the mesh shape of the target sharding.*'):
with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y')))
@unittest.skipIf(xla_extension_version < 283,
"Requires xla_extension_version >= 283")
def test_global_jit_cpp_cache_hit_out_shardings(self):
mesh = jtu.create_global_mesh((2,), 'x')
s = NamedSharding(mesh, P('x'))
def f(x):
return x * 2
with jtu.count_pjit_cpp_cache_miss() as count:
jax.jit(f, out_shardings=s)(np.arange(8))
jax.jit(f, out_shardings=s)(np.arange(8))
self.assertEqual(count[0], 1)
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")