[JAX] Add caching to pjit._infer_params.

When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.
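
Concretely, the key pairs the signature of the static arguments with the shaped avals of the dynamic arguments. A minimal sketch of the key computation, using the internal `jax_jit.parse_arguments` helper exactly as it appears in the pjit.py diff below (internal API, subject to change):

```python
# Sketch only: mirrors the cache-key computation in _infer_params below.
from jax import tree_util
from jax._src.api_util import shaped_abstractify
from jax._src.lib import jax_jit  # internal jaxlib module

def infer_params_cache_key(args, kwargs, static_argnums, static_argnames):
  # Split static from dynamic arguments; the static part is hashed into an
  # ArgumentSignature by the same C++ code the jit fast path uses.
  signature, dynargs = jax_jit.parse_arguments(
      args, tuple(kwargs.values()), tuple(kwargs.keys()),
      static_argnums, static_argnames, tree_util.default_registry)
  # Dynamic arguments contribute only their shape/dtype, never their values.
  return signature, tuple(shaped_abstractify(a) for a in dynargs)
```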

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.
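
This matters for the cache because avals are part of the key: a concrete aval carries a value, so two calls differing only in value would spuriously miss. A minimal illustration of the invariant the fix restores, assuming the 2024-era internal `core` API:

```python
# Sketch only: raise_to_shaped drops the value, keeping shape/dtype.
import numpy as np
from jax._src import core

a1 = core.get_aval(np.float32(1.0))  # may be a concrete (value-carrying) aval
a2 = core.get_aval(np.float32(2.0))
# Same ShapedArray either way, so they map to the same cache key.
assert core.raise_to_shaped(a1) == core.raise_to_shaped(a2)
```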

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op  new time/op  delta
jit_add_chain  60.3ms ±14%  50.7ms ±11%  -15.99%  (p=0.008 n=5+5)
```
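
The benchmark measures repeated tracing of a jit-wrapped function that calls an inner jit many times. To reproduce, assuming the benchmark was added to benchmarks/api_benchmark.py (file name inferred from context; `--benchmark_filter` is google_benchmark's standard flag):

```
python benchmarks/api_benchmark.py --benchmark_filter=jit_add_chain
```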

PiperOrigin-RevId: 645491650
Peter Hawkins 2024-06-21 13:52:19 -07:00 committed by jax authors
parent a730f6bfd3
commit 9e30079dba
6 changed files with 175 additions and 36 deletions

benchmarks/api_benchmark.py

@@ -905,5 +905,23 @@ def benchmark_lorentz63_cache_hits(state):
   jax.make_jaxpr(lambda x: training_step(x, 100, unroll=True))(x)


+@google_benchmark.register
+def jit_add_chain(state):
+  SIZE = 100
+
+  @jax.jit
+  def g(x, y):
+    return lax.add(x, y)
+
+  x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
+  while state:
+    @jax.jit
+    def f(x):
+      for i in range(SIZE):
+        x = g(x, x)
+      return x
+    f(x).block_until_ready()
+
+
 if __name__ == "__main__":
   google_benchmark.main()

jax/_src/api.py

@@ -2945,6 +2945,7 @@ def clear_backends():
   xb.local_devices.cache_clear()
   xb.process_count.cache_clear()
   dispatch.xla_primitive_callable.cache_clear()
+  pjit._infer_params_cached.cache_clear()
   pjit._pjit_lower_cached.cache_clear()
   pjit._create_pjit_jaxpr.cache_clear()  # pytype: disable=attribute-error
   pjit._cpp_pjit_cache.clear()
@@ -2970,6 +2971,7 @@ def clear_caches():
   # Clear all C++ compiled executable caches for pjit
   pjit._cpp_pjit_cache.clear()
+  pjit._infer_params_cached.cache_clear()
   xc._xla.PjitFunctionCache.clear_all()

   # Clear all C++ compiled executable caches for pmap

jax/_src/interpreters/partial_eval.py

@@ -1739,7 +1739,10 @@ class DynamicJaxprTracer(core.Tracer):
     frame = self._trace.frame
     val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
     return self if val is None else get_referent(val)

-api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
+def _dynamic_jaxpr_tracer_shaped_abstractify(x):
+  return core.raise_to_shaped(x.aval)
+api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify

 def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
   sentinel = object()

jax/_src/pjit.py

@@ -62,7 +62,9 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
+from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src import sharding
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
@@ -142,7 +144,6 @@ class PjitInfo(NamedTuple):
   In other words, this structure contains arguments to jit()/pjit(),
   preprocessed and validated.
   """
-  fun: Callable
   fun_sourceinfo: str | None
   fun_signature: inspect.Signature | None
   # Shardings, as specified by the user. These can either be UNSPECIFIED or they
@@ -168,11 +169,17 @@
   has_explicit_sharding: bool
   use_resource_env: bool  # False for jit, True for pjit

+  # Hash and compare PjitInfo by identity when used as a cache key.
+  def __hash__(self):
+    return id(self)
+
+  def __eq__(self, other):
+    return self is other

-def _python_pjit_helper(jit_info, *args, **kwargs):
-  p = _infer_params(jit_info, args, kwargs)
-  args_flat = p.args_flat
+def _python_pjit_helper(fun, jit_info, *args, **kwargs):
+  p, args_flat = _infer_params(fun, jit_info, args, kwargs)
   for arg in args_flat:
     dispatch.check_arg(arg)
@@ -185,7 +192,6 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
   except pxla.DeviceAssignmentMismatchError as e:
     fails, = e.args
     api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
-    fun = jit_info.fun
     fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
     msg = _device_assignment_mismatch_error(
         fun_name, fails, args_flat, api_name, p.arg_names)
@@ -304,6 +310,7 @@ def _read_pgle_profiler(jaxpr):
 def _cpp_pjit_evict_fn(self):
   self._clear_cache()
   _create_pjit_jaxpr.evict_function(self._fun)  # pytype: disable=attribute-error
+  _infer_params_cached.cache_clear()


 # The entries are doubled here from the default 4096 because _pjit_call_impl
@@ -319,12 +326,12 @@ def _get_cpp_global_cache(pjit_has_explicit_sharding):
   return _cpp_pjit_cache


-def _cpp_pjit(jit_info: PjitInfo):
+def _cpp_pjit(fun: Callable, jit_info: PjitInfo):

   @api_boundary
   def cache_miss(*args, **kwargs):
     outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
-        jit_info, *args, **kwargs)
+        fun, jit_info, *args, **kwargs)
     executable = _read_most_recent_pjit_call_executable(jaxpr)
     pgle_profiler = _read_pgle_profiler(jaxpr)
     maybe_fastpath_data = _get_fastpath_data(
@@ -334,7 +341,6 @@ def _cpp_pjit(jit_info: PjitInfo):
     return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

-  fun = jit_info.fun
   cpp_pjit_f = xc._xla.pjit(
       fun_name(fun),
       fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
@@ -448,7 +454,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
       in_shardings, out_shardings, device, backend)

   return PjitInfo(
-      fun=fun,
       fun_sourceinfo=fun_sourceinfo,
       fun_signature=fun_signature,
       user_specified_in_shardings=user_specified_in_shardings,
@@ -469,7 +474,7 @@
       use_resource_env=use_resource_env)


-def _make_jit_wrapper(jit_info: PjitInfo):
+def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):

   @api_boundary
   def lower(*args, **kwargs):
@@ -478,7 +483,6 @@ def _make_jit_wrapper(jit_info: PjitInfo):
       return traced.lower()
     except pxla.DeviceAssignmentMismatchError as e:
       fails, = e.args
-      fun = jit_info.fun
       fun_name = getattr(fun, '__qualname__',
                          getattr(fun, '__name__', str(fun)))
       msg = _device_assignment_mismatch_error(
@@ -487,7 +491,7 @@
   @api_boundary
   def eval_shape(*args, **kwargs):
-    p = _infer_params(jit_info, args, kwargs)
+    p, _ = _infer_params(fun, jit_info, args, kwargs)
     out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
     # TODO(yashkatariya): Add `Layout` to SDS.
     out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
@@ -496,16 +500,16 @@
   @api_boundary
   def trace(*args, **kwargs) -> stages.Traced:
-    p = _infer_params(jit_info, args, kwargs)
+    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
     donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
     args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
-    lower_callable = partial(_resolve_and_lower, p.args_flat, **p.params,
+    lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
                              pgle_profiler=None)
     return stages.Traced(
         p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
-        lower_callable, p.args_flat, p.arg_names, p.num_consts)
+        lower_callable, args_flat, p.arg_names, p.num_consts)

-  wrapped = _cpp_pjit(jit_info)
+  wrapped = _cpp_pjit(fun, jit_info)
   wrapped.lower = lower
   wrapped.eval_shape = eval_shape
   wrapped.trace = trace
@@ -525,11 +529,11 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
       fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
       static_argnums, static_argnames, device, backend, abstracted_axes,
       keep_unused, inline, use_resource_env)
-  return _make_jit_wrapper(jit_info)
+  return _make_jit_wrapper(fun, jit_info)


 class PjitParams(NamedTuple):
-  args_flat: list[Any]
+  consts: list[Any]  # Only jaxpr constants, we can't keep other arguments alive
   params: dict[str, Any]
   in_avals: tuple[core.AbstractValue, ...]
   in_tree: PyTreeDef
@@ -540,34 +544,34 @@
   attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]


-def _infer_params(
-    ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
-) -> PjitParams:
+def _infer_params_impl(
+    fun: Callable,
+    ji: PjitInfo,
+    pjit_mesh: mesh_lib.Mesh | None,
+    resource_env: mesh_lib.ResourceEnv | None,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    in_avals: tuple[core.AbstractValue, ...] | None,
+) -> tuple[PjitParams, list[Any]]:
   have_kwargs = bool(kwargs)
   if have_kwargs and ji.user_specified_in_shardings:
     raise ValueError(
         "pjit does not support kwargs when in_shardings is specified.")

-  if ji.use_resource_env:
-    # We need to fetch the mesh from inside the wrapped function, because
-    # meshes are dynamically scoped (i.e., with a context manager).
-    resource_env = mesh_lib.thread_resources.env
-    pjit_mesh = resource_env.physical_mesh
+  if pjit_mesh is not None:
     jit_name = 'pjit'
     if (ji.backend or ji.device) and not pjit_mesh.empty:
       raise ValueError(
           "Mesh context manager should not be used with jit when backend or "
           "device is also specified as an argument to jit.")
   else:
-    resource_env = None
-    pjit_mesh = None
     jit_name = 'jit'

   axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
   dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
                    ji.static_argnums, ji.static_argnames)
-  f = lu.wrap_init(ji.fun)
+  f = lu.wrap_init(fun)
   f, res_paths = result_paths(f)
   f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
   del args
@@ -608,7 +612,7 @@ def _infer_params(
   if config.dynamic_shapes.value:
     in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
     in_avals = tuple(a for a, e in in_type if e)
-  else:
+  elif in_avals is None:
     avals = []
     for i, a in enumerate(explicit_args):
       try:
@@ -621,6 +625,8 @@
             f"computation, whose {arg_path}."
         ) from e
     in_type = in_avals = tuple(avals)
+  else:
+    in_type = in_avals

   in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
       in_shardings_treedef, in_shardings_leaves,
@@ -667,9 +673,78 @@
       keep_unused=ji.keep_unused,
       inline=ji.inline,
   )
-  return PjitParams(consts + args_flat, params, in_avals, in_tree, out_tree(),
-                    donated_invars, dbg.arg_names if dbg else None, len(consts),
-                    attrs_tracked)
+  return PjitParams(consts, params, in_avals, in_tree, out_tree(),
+                    donated_invars, dbg.arg_names if dbg else None, len(consts),
+                    attrs_tracked), args_flat
+
+
+class InferParamsCacheEntry:
+  """Mutable value object for _infer_params_cached."""
+  __slots__ = ['pjit_params']
+
+  pjit_params: PjitParams | None
+
+  def __init__(self):
+    self.pjit_params = None
+
+
+# We use an outer cache that is keyed on the signature of the arguments, but
+# when populating a cache entry using _infer_params_impl, we need to provide
+# actual arguments. In principle we could refactor _infer_params_impl to look
+# only at an argument signature instead of args/kwargs in those cases that we
+# cache, but this was a more minimal change.
+@util.weakref_lru_cache
+def _infer_params_cached(
+    fun: Callable,
+    jit_info: PjitInfo,
+    signature: jax_jit.ArgumentSignature,
+    in_avals: tuple[core.AbstractValue, ...],
+    pjit_mesh: mesh_lib.Mesh | None,
+    resource_env: mesh_lib.ResourceEnv | None,
+) -> InferParamsCacheEntry:
+  return InferParamsCacheEntry()
+
+
+def _infer_params(
+    fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> tuple[PjitParams, list[Any]]:
+  if ji.use_resource_env:
+    # We need to fetch the mesh from inside the wrapped function, because
+    # meshes are dynamically scoped (i.e., with a context manager).
+    resource_env = mesh_lib.thread_resources.env
+    pjit_mesh = resource_env.physical_mesh
+  else:
+    resource_env = None
+    pjit_mesh = None
+
+  skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value
+  if not skip_cache:
+    signature, dynargs = jax_jit.parse_arguments(
+        args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
+        ji.static_argnames, tree_util.default_registry)
+    try:
+      avals = tuple(shaped_abstractify(a) for a in dynargs)
+    except (OverflowError, TypeError):
+      # If we see something we don't understand, use the slow path.
+      skip_cache = True
+
+  if skip_cache:
+    p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
+                                      kwargs, in_avals=None)
+    return p, p.consts + args_flat
+
+  entry = _infer_params_cached(
+      fun, ji, signature, avals, pjit_mesh, resource_env)
+  if entry.pjit_params is None:
+    p, args_flat = _infer_params_impl(
+        fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
+    if p.attrs_tracked:
+      # If there are attrs_tracked, don't use the cache.
+      return p, p.consts + args_flat
+    else:
+      entry.pjit_params = p
+  return entry.pjit_params, entry.pjit_params.consts + dynargs
+
+
 def _extract_implicit_args(
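
Abstracting away the pjit details, the two-level scheme above has this shape (a minimal sketch, not JAX's actual code; `expensive_trace` is a hypothetical stand-in for `_infer_params_impl`):

```python
from jax._src import util

class _Entry:
  """Mutable holder, so the cached value can be filled in after lookup."""
  __slots__ = ['value']
  def __init__(self):
    self.value = None

@util.weakref_lru_cache  # keeps only a weak reference to `fun`
def _lookup(fun, signature):
  return _Entry()  # each distinct (fun, signature) maps to one shared entry

def infer(fun, signature, args):
  entry = _lookup(fun, signature)
  if entry.value is None:  # first miss: do the real work with concrete args
    entry.value = expensive_trace(fun, args)  # hypothetical stand-in
  return entry.value
```

The indirection through a mutable entry lets the cache be keyed on the cheap argument signature while the expensive computation still receives the actual args/kwargs, which is the trade-off the comment in the diff above describes.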

jax/_src/test_util.py

@@ -15,6 +15,7 @@
 # pyformat: disable
 from __future__ import annotations

+import collections
 from collections.abc import Generator, Iterable, Sequence
 from contextlib import ExitStack, contextmanager
 import datetime
@@ -315,6 +316,21 @@ def count_jit_tracing_cache_miss():
   finally:
     pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr


+@contextmanager
+def count_jit_infer_params_cache_miss():
+  original_infer_params_impl = pjit_lib._infer_params_impl
+  count = collections.defaultdict(int)
+
+  def infer_params_impl_and_count(fun, *args, **kw):
+    count[fun] += 1
+    return original_infer_params_impl(fun, *args, **kw)
+
+  pjit_lib._infer_params_impl = infer_params_impl_and_count
+  try:
+    yield count
+  finally:
+    pjit_lib._infer_params_impl = original_infer_params_impl
+
+
 @contextmanager
 def count_aot_jit_cpp_cache_miss():

tests/api_test.py

@@ -53,6 +53,7 @@ from jax._src import linear_util as lu
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src import debugging
+from jax._src import pjit as pjit_lib
 from jax._src.ad_checkpoint import saved_residuals
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
@@ -2588,7 +2589,7 @@ class APITest(jtu.JaxTestCase):
   def test_eval_shape_trace_cache_share(self):
     def f(x):
-      return x * 2
+      return x

     inp = np.arange(8)
@@ -2596,8 +2597,32 @@
       jax.eval_shape(f, inp)
       jax.jit(f)(inp)

-    # one for `f` and another for mul (`x * 2`) which is jitted.
-    self.assertEqual(count[0], 2)
+    self.assertEqual(count[0], 1)

+  @unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31")
+  def test_jit_infer_params_cache(self):
+    def f(x):
+      return x
+
+    f_jit = jax.jit(f)
+
+    def g(x):
+      x = f_jit(x)  # noqa: F821
+      x = f_jit(x)  # noqa: F821
+      return x
+
+    g_jit = jax.jit(g)
+
+    inp = np.arange(8)
+    with jtu.count_jit_infer_params_cache_miss() as count:
+      g_jit(inp)
+    self.assertDictEqual(count, {f: 1, g: 1})
+
+    cache_size = pjit_lib._infer_params_cached.cache_info().currsize
+    del count, f, f_jit, g, g_jit
+    # Cache should only keep a weak reference to f and g.
+    self.assertLess(pjit_lib._infer_params_cached.cache_info().currsize,
+                    cache_size, msg=pjit_lib._infer_params_cached.cache_keys())
+
   def test_eval_shape_out_shardings(self):
     s = jax.sharding.SingleDeviceSharding(jax.devices()[0])