commit b615266175 (parent 7dd9adba05)

Reverts 82c9da020a78997862a8f7ccd494bed363f7ed01

PiperOrigin-RevId: 668969133
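Note for readers of this revert: the change being backed out had split the global C++ pjit cache in two and routed each jit through `JitGlobalCppCacheKeys.contains_explicit_attributes`. The sketch below restates that routing idea in pure Python with hypothetical names (`LruCache`, `get_global_cache`); the real caches are `xc._xla.PjitFunctionCache` objects implemented in C++.

```python
from collections import OrderedDict

class LruCache:
    """Toy LRU cache standing in for xc._xla.PjitFunctionCache."""

    def __init__(self, capacity: int = 8192):
        self.capacity = capacity
        self._entries: OrderedDict = OrderedDict()

    def get(self, key):
        if key in self._entries:
            self._entries.move_to_end(key)  # mark as most recently used
            return self._entries[key]
        return None

    def put(self, key, value):
        self._entries[key] = value
        self._entries.move_to_end(key)
        if len(self._entries) > self.capacity:
            self._entries.popitem(last=False)  # evict least recently used

# Mirrors _cpp_pjit_cache_fun_only / _cpp_pjit_cache_explicit_attributes:
# plain jax.jit(f) shares one cache; jit with donated args, shardings,
# layouts, device or backend goes to the other.
cache_fun_only = LruCache()
cache_explicit_attributes = LruCache()

def get_global_cache(contains_explicit_attributes: bool) -> LruCache:
    return (cache_explicit_attributes if contains_explicit_attributes
            else cache_fun_only)
```

The revert collapses this two-cache routing back to a single `_cpp_pjit_cache`, as the hunks below show.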
@@ -2970,8 +2970,7 @@ def clear_backends():
   pjit._infer_params_cached.cache_clear()
   pjit._pjit_lower_cached.cache_clear()
   pjit._create_pjit_jaxpr.cache_clear()  # pytype: disable=attribute-error
-  pjit._cpp_pjit_cache_fun_only.clear()
-  pjit._cpp_pjit_cache_explicit_attributes.clear()
+  pjit._cpp_pjit_cache.clear()
   xc._xla.PjitFunctionCache.clear_all()
 
 @atexit.register
@@ -2999,8 +2998,7 @@ def clear_caches():
   util.clear_all_weakref_lru_caches()
 
   # Clear all C++ compiled executable caches for pjit
-  pjit._cpp_pjit_cache_fun_only.clear()
-  pjit._cpp_pjit_cache_explicit_attributes.clear()
+  pjit._cpp_pjit_cache.clear()
   pjit._infer_params_cached.cache_clear()
   xc._xla.PjitFunctionCache.clear_all()
 
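Note: both hunks above sit behind the public `jax.clear_backends()` / `jax.clear_caches()` entry points, so the revert changes only which module-level caches get cleared, not the user-facing API. A rough way to observe the clearing (a sketch; exact recompilation behavior can vary across JAX versions and with a persistent compilation cache enabled):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) + 1.0

f(jnp.ones(4))      # first call: traces, compiles, and caches the executable
f(jnp.ones(4))      # second call: served from the pjit caches
jax.clear_caches()  # drops the tracing/lowering/executable caches above...
f(jnp.ones(4))      # ...so this call traces and compiles again
```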
@@ -22,7 +22,6 @@ from collections import namedtuple
 from collections.abc import Callable, Sequence, Iterable, Iterator
-import dataclasses
 from functools import partial, lru_cache, cached_property
 import functools
 import itertools as it
 import logging
 import math
@@ -90,7 +89,6 @@ unsafe_map, map = map, safe_map  # type: ignore
 logger = logging.getLogger(__name__)
 
 Index = Union[int, slice, tuple[Union[int, slice], ...]]
-PyTreeDef = tree_util.PyTreeDef
 
 NoSharding = sharding_specs.NoSharding
 Chunked = sharding_specs.Chunked
@@ -2907,34 +2905,6 @@ class MeshExecutableFastpathData(NamedTuple):
   in_device_local_layouts: Sequence[DeviceLocalLayout | None]
 
 
-@dataclasses.dataclass(frozen=True, kw_only=True)
-class JitGlobalCppCacheKeys:
-  donate_argnums: tuple[int, ...] | None = None
-  donate_argnames: tuple[str, ...] | None = None
-  device: xc.Device | None = None
-  backend: str | None = None
-  in_shardings_treedef: PyTreeDef | None = None
-  in_shardings_leaves: tuple[Any, ...] | None = None
-  out_shardings_treedef: PyTreeDef | None = None
-  out_shardings_leaves: tuple[Any, ...] | None = None
-  in_layouts_treedef: PyTreeDef | None = None
-  in_layouts_leaves: tuple[Any, ...] | None = None
-  out_layouts_treedef: PyTreeDef | None = None
-  out_layouts_leaves: tuple[Any, ...] | None = None
-  use_resource_env: bool = False
-
-  @functools.cached_property
-  def contains_explicit_attributes(self):
-    return (self.donate_argnums is not None or
-            self.donate_argnames is not None or
-            self.device is not None or
-            self.backend is not None or
-            any(not is_unspecified(i) for i in self.in_shardings_leaves) or
-            any(not is_unspecified(o) for o in self.out_shardings_leaves) or
-            any(i is not None for i in self.in_layouts_leaves) or
-            any(o is not None for o in self.out_layouts_leaves))
-
-
 def reflatten_outputs_for_dispatch(out_tree, out_flat):
   # We arrive at dispatch having flattened according to the default
   # pytree registry, but we want to re-flatten according to our
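Note on the deleted class: it combined a frozen, keyword-only dataclass with `functools.cached_property`. That combination works because `cached_property` writes its result directly into the instance `__dict__`, bypassing the `__setattr__` that `frozen=True` blocks. A trimmed-down sketch (field set reduced; `CacheKeys` is a hypothetical stand-in, not JAX API):

```python
import dataclasses
import functools

@dataclasses.dataclass(frozen=True, kw_only=True)
class CacheKeys:
    donate_argnums: tuple[int, ...] | None = None
    device: str | None = None

    @functools.cached_property
    def contains_explicit_attributes(self) -> bool:
        # Computed at most once per instance; the cached value is stored in
        # the instance __dict__, which a frozen dataclass still permits.
        return self.donate_argnums is not None or self.device is not None

assert not CacheKeys().contains_explicit_attributes
assert CacheKeys(donate_argnums=(0,)).contains_explicit_attributes
```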
@@ -3048,14 +3018,9 @@ class MeshExecutable(stages.XlaExecutable):
         fastpath_data = None
       return outs, fastpath_data, False  # Do not remove cache entry
 
-    if xla_extension_version >= 283:
-      return xc._xla.pjit(
-          self.unsafe_call.name, None, aot_cache_miss, [], [],
-          JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
-    else:
-      return xc._xla.pjit(
-          self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-          tree_util.dispatch_registry, cc_shard_arg)
+    return xc._xla.pjit(
+        self.unsafe_call.name, None, aot_cache_miss, [], [], [],
+        tree_util.dispatch_registry, cc_shard_arg)
 
 if xla_extension_version < 282:
   def cc_shard_arg(x, sharding):
jax/_src/pjit.py (117 lines changed)
@@ -63,7 +63,6 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src import sharding
 from jax._src.mesh import AbstractMesh
 from jax._src.sharding_impls import (
@@ -166,6 +165,7 @@ class PjitInfo(NamedTuple):
   keep_unused: bool
   inline: bool
   abstracted_axes: Any | None
+  has_explicit_sharding: bool
   use_resource_env: bool  # False for jit, True for pjit
 
   # Hash and compare PjitInfo by identity when used as a cache key.
@@ -312,39 +312,14 @@ def _cpp_pjit_evict_fn(self):
 # The entries are doubled here from the default 4096 because _pjit_call_impl
 # also has a cpp dispatch path and that would double the number of entries in
 # the global shared cache.
-# This cache is only used for jit's with only fun. For example: jax.jit(f)
-_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
-
-# This cache is used for jit where extra arguments are defined other than the
-# fun. For example: jax.jit(f, donate_argnums=...) OR
-# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the
-# capacity might get full very fast because of all the jitted function in JAX
-# which might evict train_step for example.
-_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
+_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192)
 
 
-if xla_extension_version < 283:
-  def _get_cpp_global_cache(pjit_has_explicit_sharding):
-    if pjit_has_explicit_sharding:
-      return xc._xla.PjitFunctionCache()
-    else:
-      return _cpp_pjit_cache_fun_only
-
-  def _pjit_explicit_sharding_and_layout(
-      in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
-      device, backend) -> bool:
-    return (device is not None or
-            backend is not None or
-            any(not is_unspecified(i) for i in in_shardings_flat) or
-            any(not is_unspecified(o) for o in out_shardings_flat) or
-            any(i is not None for i in in_layouts_flat) or
-            any(o is not None for o in out_layouts_flat))
-else:
-  def _get_cpp_global_cache(contains_explicit_attributes: bool):  # type: ignore
-    if contains_explicit_attributes:
-      return _cpp_pjit_cache_explicit_attributes
-    else:
-      return _cpp_pjit_cache_fun_only
+def _get_cpp_global_cache(pjit_has_explicit_sharding):
+  if pjit_has_explicit_sharding:
+    return xc._xla.PjitFunctionCache()
+  else:
+    return _cpp_pjit_cache
 
 
 def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
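Note: the deleted comment block spells out the motivation for the two-cache scheme: in one shared LRU cache, a flood of short-lived jitted functions can evict a hot entry such as a `train_step`. A minimal `functools.lru_cache` demonstration of that eviction pattern (cache size shrunk to 4 for illustration):

```python
from functools import lru_cache

@lru_cache(maxsize=4)             # stand-in for a shared PjitFunctionCache
def compile_fn(name: str) -> str:
    return f"executable<{name}>"  # pretend this is an expensive compile

compile_fn("train_step")          # the hot entry enters the cache
for i in range(4):                # four one-off helpers...
    compile_fn(f"helper_{i}")     # ...fill the cache and evict train_step
compile_fn("train_step")          # miss: would recompile in real life
assert compile_fn.cache_info().misses == 6
assert compile_fn.cache_info().hits == 0
```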
@@ -365,35 +340,11 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
 
     return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
 
-  if xla_extension_version >= 283:
-    cache_key = pxla.JitGlobalCppCacheKeys(
-        donate_argnums=jit_info.donate_argnums,
-        donate_argnames=jit_info.donate_argnames,
-        device=jit_info.device, backend=jit_info.backend,
-        in_shardings_treedef=jit_info.in_shardings_treedef,
-        in_shardings_leaves=jit_info.in_shardings_leaves,
-        out_shardings_treedef=jit_info.out_shardings_treedef,
-        out_shardings_leaves=jit_info.out_shardings_leaves,
-        in_layouts_treedef=jit_info.in_layouts_treedef,
-        in_layouts_leaves=jit_info.in_layouts_leaves,
-        out_layouts_treedef=jit_info.out_layouts_treedef,
-        out_layouts_leaves=jit_info.out_layouts_leaves,
-        use_resource_env=jit_info.use_resource_env)
-    cpp_pjit_f = xc._xla.pjit(
-        fun_name(fun), fun, cache_miss, jit_info.static_argnums,
-        jit_info.static_argnames, cache_key, tree_util.dispatch_registry,  # type: ignore
-        pxla.cc_shard_arg,
-        _get_cpp_global_cache(cache_key.contains_explicit_attributes))
-  else:
-    has_explicit_sharding = _pjit_explicit_sharding_and_layout(
-        jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
-        jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
-        jit_info.device, jit_info.backend)
-    cpp_pjit_f = xc._xla.pjit(
-        fun_name(fun), fun, cache_miss, jit_info.static_argnums,
-        jit_info.static_argnames, jit_info.donate_argnums,
-        tree_util.dispatch_registry, pxla.cc_shard_arg,
-        _get_cpp_global_cache(has_explicit_sharding))
+  cpp_pjit_f = xc._xla.pjit(
+      fun_name(fun),
+      fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
+      jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg,
+      _get_cpp_global_cache(jit_info.has_explicit_sharding))
 
   cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
   cpp_pjitted_f._fun = fun
@@ -401,6 +352,17 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
   return cpp_pjitted_f
 
 
+def _pjit_explicit_sharding_and_layout(
+    in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
+    device, backend) -> bool:
+  return (device is not None or
+          backend is not None or
+          any(not is_unspecified(i) for i in in_shardings_flat) or
+          any(not is_unspecified(o) for o in out_shardings_flat) or
+          any(i is not None for i in in_layouts_flat) or
+          any(o is not None for o in out_layouts_flat))
+
+
 def _split_layout_and_sharding(entries):
   entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
   layouts, shardings = [], []
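Note: the restored `_pjit_explicit_sharding_and_layout` is a pure predicate over the flattened jit arguments. Roughly, with `UNSPECIFIED` standing in for the sentinel that `is_unspecified` tests:

```python
UNSPECIFIED = object()  # stand-in for JAX's internal UnspecifiedValue sentinel

def explicit(in_shardings, out_shardings, in_layouts, out_layouts,
             device, backend) -> bool:
    return (device is not None or
            backend is not None or
            any(s is not UNSPECIFIED for s in in_shardings) or
            any(s is not UNSPECIFIED for s in out_shardings) or
            any(l is not None for l in in_layouts) or
            any(l is not None for l in out_layouts))

# jax.jit(f): nothing explicit, so the shared global cache is used.
assert not explicit([UNSPECIFIED], [UNSPECIFIED], [None], [None], None, None)
# jax.jit(f, backend='cpu'): explicit, so a private cache is used instead.
assert explicit([UNSPECIFIED], [UNSPECIFIED], [None], [None], None, 'cpu')
```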
@@ -484,6 +446,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
       fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
       static_argnames)
 
+  has_explicit_sharding = _pjit_explicit_sharding_and_layout(
+      in_shardings_leaves, out_shardings_leaves, in_layouts_leaves,
+      out_layouts_leaves, device, backend)
+
   return PjitInfo(
       fun_sourceinfo=fun_sourceinfo,
       fun_signature=fun_signature,
@@ -501,6 +467,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
       donate_argnames=donate_argnames, device=device, backend=backend,
       keep_unused=keep_unused, inline=inline,
       abstracted_axes=abstracted_axes,
+      has_explicit_sharding=has_explicit_sharding,
       use_resource_env=use_resource_env)
 
 
@@ -1766,27 +1733,13 @@ def _pjit_call_impl(*args, jaxpr,
   f = _get_jaxpr_as_fun(
       jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
       resource_env, donated_invars, name, keep_unused, inline)
-  donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
-  if xla_extension_version >= 283:
-    cache_key = pxla.JitGlobalCppCacheKeys(
-        donate_argnums=donated_argnums, donate_argnames=None,
-        device=None, backend=None,
-        in_shardings_treedef=None, in_shardings_leaves=in_shardings,
-        out_shardings_treedef=None, out_shardings_leaves=out_shardings,
-        in_layouts_treedef=None, in_layouts_leaves=in_layouts,
-        out_layouts_treedef=None, out_layouts_leaves=out_layouts,
-        use_resource_env=resource_env is not None)
-    return xc._xla.pjit(
-        name, f, call_impl_cache_miss, [], [], cache_key,
-        tree_util.dispatch_registry, pxla.cc_shard_arg,
-        _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
-  else:
-    has_explicit_sharding = _pjit_explicit_sharding_and_layout(
-        in_shardings, out_shardings, in_layouts, out_layouts, None, None)
-    return xc._xla.pjit(
-        name, f, call_impl_cache_miss, [], [], donated_argnums,
-        tree_util.dispatch_registry, pxla.cc_shard_arg,
-        _get_cpp_global_cache(has_explicit_sharding))(*args)
+  donated_argnums = [i for i, d in enumerate(donated_invars) if d]
+  has_explicit_sharding = _pjit_explicit_sharding_and_layout(
+      in_shardings, out_shardings, in_layouts, out_layouts, None, None)
+  return xc._xla.pjit(
+      name, f, call_impl_cache_miss, [], [], donated_argnums,
+      tree_util.dispatch_registry, pxla.cc_shard_arg,
+      _get_cpp_global_cache(has_explicit_sharding))(*args)
 
 pjit_p.def_impl(_pjit_call_impl)
 
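Note: the restored `_pjit_call_impl` rebuilds `donated_argnums` as positional indices from the per-argument boolean mask `donated_invars` (the reverted code built a tuple; the restored code builds a list). For example:

```python
donated_invars = (False, True, False, True)  # which arguments are donated
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
assert donated_argnums == [1, 3]             # indices handed to xc._xla.pjit
```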
@@ -90,17 +90,19 @@ def sync_global_devices(name: str):
   assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
 
 
 # Identity function is at the top level so that `process_allgather` doesn't
 # recompile on every invocation.
 def _identity_fn(x):
   return x
 
+@lru_cache(maxsize=128)
+def _jitted_identity_fn(sharding):
+  return jax.jit(_identity_fn, out_shardings=sharding)
+
 
 def _handle_array_process_allgather(inp, tiled):
   if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
     reps = sharding_impls.GSPMDSharding.get_replicated(
         inp.sharding._device_assignment)
-    out = jax.jit(_identity_fn, out_shardings=reps)(inp)
+    out = _jitted_identity_fn(reps)(inp)
   else:
     # All inputs here will be fully addressable.
     if jax.process_count() == 1:
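Note: the restored `_jitted_identity_fn` memoizes one jitted callable per output sharding, so repeated `process_allgather` calls reuse the same `jax.jit` wrapper rather than constructing a fresh one each time. The same pattern in user code (a sketch; assumes at least one device and relies on shardings being hashable):

```python
from functools import lru_cache

import jax
import jax.numpy as jnp

def _identity(x):
    return x

@lru_cache(maxsize=128)
def jitted_identity(sharding):
    # Shardings hash and compare by value, so they are usable as cache keys.
    return jax.jit(_identity, out_shardings=sharding)

s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
assert jitted_identity(s) is jitted_identity(s)  # wrapper is built only once
out = jitted_identity(s)(jnp.arange(4))
```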
@@ -123,8 +125,7 @@ def _handle_array_process_allgather(inp, tiled):
       bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
       global_arr = array.make_array_from_single_device_arrays(
           global_aval.shape, s, bufs)
-      out = jax.jit(_identity_fn,
-                    out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr)
+      out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr)
 
   return np.asarray(out.addressable_data(0))
 
@@ -653,16 +653,18 @@ class PJitTest(jtu.BufferDonationTestCase):
 
   @jtu.with_mesh([('x', 2), ('y', 1)])
   def testAutodiffCache(self):
-    f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None)
+    f = pjit(
+        lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None
+    )
     x = jnp.arange(16, dtype=jnp.float32)
 
     jax.grad(f)(x)  # Warm up the cache.
-    with jtu.count_pjit_cpp_cache_miss() as count:
-      jax.grad(f)(x)
-    if xla_extension_version >= 283:
-      self.assertEqual(count[0], 0)  # no cache miss i.e. cache hit
-    else:
-      self.assertEqual(count[0], 2)
+    before = pjit_lib._pjit_lower_cached.cache_info()
+    jax.grad(f)(x)
+    after = pjit_lib._pjit_lower_cached.cache_info()
+
+    # One hit for the forward pass, one hit for backward.
+    self.assertEqual(after.hits, before.hits + 2)
+    self.assertEqual(after.misses, before.misses)
 
   @jtu.with_mesh([('x', 2), ('y', 1)])
   def testEvalJaxpr(self):
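Note: the restored assertions lean on the `cache_info()` counters that `_pjit_lower_cached` exposes (the same interface as `functools.lru_cache`): hits and misses are cumulative, hence the before/after delta. A self-contained illustration of the counting:

```python
from functools import lru_cache

@lru_cache
def lower(key):
    return key * 2          # stand-in for an expensive lowering

lower("fwd"); lower("bwd")  # two misses populate the cache
before = lower.cache_info()
lower("fwd"); lower("bwd")  # one hit for forward, one for backward
after = lower.cache_info()
assert after.hits == before.hits + 2
assert after.misses == before.misses
```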
@@ -4536,20 +4538,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
         ' match the mesh shape of the target sharding.*'):
       with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y')))
 
-  @unittest.skipIf(xla_extension_version < 283,
-                   "Requires xla_extension_version >= 283")
-  def test_global_jit_cpp_cache_hit_out_shardings(self):
-    mesh = jtu.create_global_mesh((2,), 'x')
-    s = NamedSharding(mesh, P('x'))
-
-    def f(x):
-      return x * 2
-
-    with jtu.count_pjit_cpp_cache_miss() as count:
-      jax.jit(f, out_shardings=s)(np.arange(8))
-      jax.jit(f, out_shardings=s)(np.arange(8))
-    self.assertEqual(count[0], 1)
-
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")