Reverts 82c9da020a78997862a8f7ccd494bed363f7ed01

PiperOrigin-RevId: 668969133
This commit is contained in:
Yash Katariya 2024-08-29 09:42:35 -07:00 committed by jax authors
parent 7dd9adba05
commit b615266175
5 changed files with 56 additions and 151 deletions

View File

@ -2970,8 +2970,7 @@ def clear_backends():
pjit._infer_params_cached.cache_clear()
pjit._pjit_lower_cached.cache_clear()
pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error
pjit._cpp_pjit_cache_fun_only.clear()
pjit._cpp_pjit_cache_explicit_attributes.clear()
pjit._cpp_pjit_cache.clear()
xc._xla.PjitFunctionCache.clear_all()
@atexit.register
@ -2999,8 +2998,7 @@ def clear_caches():
util.clear_all_weakref_lru_caches()
# Clear all C++ compiled executable caches for pjit
pjit._cpp_pjit_cache_fun_only.clear()
pjit._cpp_pjit_cache_explicit_attributes.clear()
pjit._cpp_pjit_cache.clear()
pjit._infer_params_cached.cache_clear()
xc._xla.PjitFunctionCache.clear_all()

View File

@ -22,7 +22,6 @@ from collections import namedtuple
from collections.abc import Callable, Sequence, Iterable, Iterator
import dataclasses
from functools import partial, lru_cache, cached_property
import functools
import itertools as it
import logging
import math
@ -90,7 +89,6 @@ unsafe_map, map = map, safe_map # type: ignore
logger = logging.getLogger(__name__)
Index = Union[int, slice, tuple[Union[int, slice], ...]]
PyTreeDef = tree_util.PyTreeDef
NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
@ -2907,34 +2905,6 @@ class MeshExecutableFastpathData(NamedTuple):
in_device_local_layouts: Sequence[DeviceLocalLayout | None]
@dataclasses.dataclass(frozen=True, kw_only=True)
class JitGlobalCppCacheKeys:
donate_argnums: tuple[int, ...] | None = None
donate_argnames: tuple[str, ...] | None = None
device: xc.Device | None = None
backend: str | None = None
in_shardings_treedef: PyTreeDef | None = None
in_shardings_leaves: tuple[Any, ...] | None = None
out_shardings_treedef: PyTreeDef | None = None
out_shardings_leaves: tuple[Any, ...] | None = None
in_layouts_treedef: PyTreeDef | None = None
in_layouts_leaves: tuple[Any, ...] | None = None
out_layouts_treedef: PyTreeDef | None = None
out_layouts_leaves: tuple[Any, ...] | None = None
use_resource_env: bool = False
@functools.cached_property
def contains_explicit_attributes(self):
return (self.donate_argnums is not None or
self.donate_argnames is not None or
self.device is not None or
self.backend is not None or
any(not is_unspecified(i) for i in self.in_shardings_leaves) or
any(not is_unspecified(o) for o in self.out_shardings_leaves) or
any(i is not None for i in self.in_layouts_leaves) or
any(o is not None for o in self.out_layouts_leaves))
def reflatten_outputs_for_dispatch(out_tree, out_flat):
# We arrive at dispatch having flattened according to the default
# pytree registry, but we want to re-flatten according to our
@ -3048,14 +3018,9 @@ class MeshExecutable(stages.XlaExecutable):
fastpath_data = None
return outs, fastpath_data, False # Do not remove cache entry
if xla_extension_version >= 283:
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [],
JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
else:
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)
if xla_extension_version < 282:
def cc_shard_arg(x, sharding):

View File

@ -63,7 +63,6 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import sharding
from jax._src.mesh import AbstractMesh
from jax._src.sharding_impls import (
@ -166,6 +165,7 @@ class PjitInfo(NamedTuple):
keep_unused: bool
inline: bool
abstracted_axes: Any | None
has_explicit_sharding: bool
use_resource_env: bool # False for jit, True for pjit
# Hash and compare PjitInfo by identity when used as a cache key.
@ -312,39 +312,14 @@ def _cpp_pjit_evict_fn(self):
# The entries are doubled here from the default 4096 because _pjit_call_impl
# also has a cpp dispatch path and that would double the number of entries in
# the global shared cache.
# This cache is only used for jit's with only fun. For example: jax.jit(f)
_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
# This cache is used for jit where extra arguments are defined other than the
# fun. For example: jax.jit(f, donate_argnums=...) OR
# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the
# capacity might get full very fast because of all the jitted function in JAX
# which might evict train_step for example.
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192)
if xla_extension_version < 283:
def _get_cpp_global_cache(pjit_has_explicit_sharding):
if pjit_has_explicit_sharding:
return xc._xla.PjitFunctionCache()
else:
return _cpp_pjit_cache_fun_only
def _pjit_explicit_sharding_and_layout(
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
device, backend) -> bool:
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(o) for o in out_shardings_flat) or
any(i is not None for i in in_layouts_flat) or
any(o is not None for o in out_layouts_flat))
else:
def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore
if contains_explicit_attributes:
return _cpp_pjit_cache_explicit_attributes
else:
return _cpp_pjit_cache_fun_only
def _get_cpp_global_cache(pjit_has_explicit_sharding):
if pjit_has_explicit_sharding:
return xc._xla.PjitFunctionCache()
else:
return _cpp_pjit_cache
def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
@ -365,35 +340,11 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
if xla_extension_version >= 283:
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=jit_info.donate_argnums,
donate_argnames=jit_info.donate_argnames,
device=jit_info.device, backend=jit_info.backend,
in_shardings_treedef=jit_info.in_shardings_treedef,
in_shardings_leaves=jit_info.in_shardings_leaves,
out_shardings_treedef=jit_info.out_shardings_treedef,
out_shardings_leaves=jit_info.out_shardings_leaves,
in_layouts_treedef=jit_info.in_layouts_treedef,
in_layouts_leaves=jit_info.in_layouts_leaves,
out_layouts_treedef=jit_info.out_layouts_treedef,
out_layouts_leaves=jit_info.out_layouts_leaves,
use_resource_env=jit_info.use_resource_env)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))
else:
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
jit_info.device, jit_info.backend)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, jit_info.donate_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))
cpp_pjit_f = xc._xla.pjit(
fun_name(fun),
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(jit_info.has_explicit_sharding))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
@ -401,6 +352,17 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
return cpp_pjitted_f
def _pjit_explicit_sharding_and_layout(
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
device, backend) -> bool:
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(o) for o in out_shardings_flat) or
any(i is not None for i in in_layouts_flat) or
any(o is not None for o in out_layouts_flat))
def _split_layout_and_sharding(entries):
entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
layouts, shardings = [], []
@ -484,6 +446,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
static_argnames)
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings_leaves, out_shardings_leaves, in_layouts_leaves,
out_layouts_leaves, device, backend)
return PjitInfo(
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
@ -501,6 +467,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
has_explicit_sharding=has_explicit_sharding,
use_resource_env=use_resource_env)
@ -1766,27 +1733,13 @@ def _pjit_call_impl(*args, jaxpr,
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
if xla_extension_version >= 283:
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=donated_argnums, donate_argnames=None,
device=None, backend=None,
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
use_resource_env=resource_env is not None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], cache_key,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
else:
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)
pjit_p.def_impl(_pjit_call_impl)

View File

@ -90,17 +90,19 @@ def sync_global_devices(name: str):
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
return x
@lru_cache(maxsize=128)
def _jitted_identity_fn(sharding):
return jax.jit(_identity_fn, out_shardings=sharding)
def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment)
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
out = _jitted_identity_fn(reps)(inp)
else:
# All inputs here will be fully addressable.
if jax.process_count() == 1:
@ -123,8 +125,7 @@ def _handle_array_process_allgather(inp, tiled):
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
global_aval.shape, s, bufs)
out = jax.jit(_identity_fn,
out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr)
out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr)
return np.asarray(out.addressable_data(0))

View File

@ -653,16 +653,18 @@ class PJitTest(jtu.BufferDonationTestCase):
@jtu.with_mesh([('x', 2), ('y', 1)])
def testAutodiffCache(self):
f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None)
f = pjit(
lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None
)
x = jnp.arange(16, dtype=jnp.float32)
jax.grad(f)(x) # Warm up the cache.
with jtu.count_pjit_cpp_cache_miss() as count:
jax.grad(f)(x)
if xla_extension_version >= 283:
self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
else:
self.assertEqual(count[0], 2)
before = pjit_lib._pjit_lower_cached.cache_info()
jax.grad(f)(x)
after = pjit_lib._pjit_lower_cached.cache_info()
# One hit for the forward pass, one hit for backward.
self.assertEqual(after.hits, before.hits + 2)
self.assertEqual(after.misses, before.misses)
@jtu.with_mesh([('x', 2), ('y', 1)])
def testEvalJaxpr(self):
@ -4536,20 +4538,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
' match the mesh shape of the target sharding.*'):
with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y')))
@unittest.skipIf(xla_extension_version < 283,
"Requires xla_extension_version >= 283")
def test_global_jit_cpp_cache_hit_out_shardings(self):
mesh = jtu.create_global_mesh((2,), 'x')
s = NamedSharding(mesh, P('x'))
def f(x):
return x * 2
with jtu.count_pjit_cpp_cache_miss() as count:
jax.jit(f, out_shardings=s)(np.arange(8))
jax.jit(f, out_shardings=s)(np.arange(8))
self.assertEqual(count[0], 1)
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")