[JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op  new time/op  delta
jit_add_chain  60.3ms ±14%  50.7ms ±11%  -15.99%  (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
parent a730f6bfd3
commit 9e30079dba
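Before the diff, it may help to see the shape of the caching pattern the change introduces: a bounded cache keyed on the function plus a signature of its dynamic arguments hands out a mutable entry object, and the first caller populates that entry by running the existing slow path. The sketch below is illustrative only; every name in it is made up. The real implementation is `_infer_params_cached`, `InferParamsCacheEntry`, and `_infer_params_impl` in the `jax/_src/pjit.py` hunks further down, which additionally key on the mesh and resource env and use JAX's weakref LRU cache.

```python
# A minimal, self-contained sketch of the caching pattern (hypothetical
# names; the real code lives in the pjit.py hunks below).
import functools

import numpy as np


class _Entry:
  """Mutable value object: the cache returns an empty entry, and the first
  caller fills it in (mirrors InferParamsCacheEntry in the diff)."""
  __slots__ = ['params']

  def __init__(self):
    self.params = None


@functools.lru_cache(maxsize=4096)  # the real cache is a weakref LRU cache
def _lookup(fun, signature):
  return _Entry()


def _infer_params_slow(fun, args):
  """Placeholder for the expensive tracing work (_infer_params_impl)."""
  return {'fun': getattr(fun, '__name__', repr(fun)), 'n_args': len(args)}


def infer_params(fun, args):
  # Reduce the dynamic arguments to a hashable signature. This stands in
  # for jax_jit.parse_arguments + shaped_abstractify in the real code.
  try:
    signature = tuple((a.shape, a.dtype.str) for a in args)
  except (AttributeError, OverflowError, TypeError):
    return _infer_params_slow(fun, args)  # unrecognized inputs: slow path
  entry = _lookup(fun, signature)
  if entry.params is None:
    # Cache miss: populate the entry once, using the actual arguments.
    entry.params = _infer_params_slow(fun, args)
  return entry.params


def f(x):
  return x

x = np.arange(8)
assert infer_params(f, (x,)) is infer_params(f, (x,))  # second call is a hit
```

Returning a mutable entry instead of caching the computed result directly lets the implementation decline to cache after it has already done the work, which the new `_infer_params` uses to bypass the cache when attrs are tracked.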
```diff
@@ -905,5 +905,23 @@ def benchmark_lorentz63_cache_hits(state):
   jax.make_jaxpr(lambda x: training_step(x, 100, unroll=True))(x)
 
 
+@google_benchmark.register
+def jit_add_chain(state):
+  SIZE = 100
+
+  @jax.jit
+  def g(x, y):
+    return lax.add(x, y)
+
+  x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
+  while state:
+    @jax.jit
+    def f(x):
+      for i in range(SIZE):
+        x = g(x, x)
+      return x
+    f(x).block_until_ready()
+
+
 if __name__ == "__main__":
   google_benchmark.main()
```
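This benchmark exercises exactly the path the change targets: every pass through `while state` builds a fresh `jax.jit(f)` and re-traces it, and tracing `f` invokes the inner jit `g` a hundred times, so the loop's cost is dominated by repeatedly processing the inner jit. With the new cache, only the first `g(x, x)` call for a given argument signature pays the full `_infer_params` cost, which is where the roughly 16% cpu/op improvement quoted in the commit message comes from.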
```diff
@@ -2945,6 +2945,7 @@ def clear_backends():
   xb.local_devices.cache_clear()
   xb.process_count.cache_clear()
   dispatch.xla_primitive_callable.cache_clear()
+  pjit._infer_params_cached.cache_clear()
   pjit._pjit_lower_cached.cache_clear()
   pjit._create_pjit_jaxpr.cache_clear()  # pytype: disable=attribute-error
   pjit._cpp_pjit_cache.clear()
```
```diff
@@ -2970,6 +2971,7 @@ def clear_caches():
 
   # Clear all C++ compiled executable caches for pjit
   pjit._cpp_pjit_cache.clear()
+  pjit._infer_params_cached.cache_clear()
   xc._xla.PjitFunctionCache.clear_all()
 
   # Clear all C++ compiled executable caches for pmap
```
```diff
@@ -1739,7 +1739,10 @@ class DynamicJaxprTracer(core.Tracer):
     frame = self._trace.frame
     val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
     return self if val is None else get_referent(val)
-api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
+
+def _dynamic_jaxpr_tracer_shaped_abstractify(x):
+  return core.raise_to_shaped(x.aval)
+api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
 
 def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
   sentinel = object()
```
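This hunk is the "in passing" bug fix from the commit message: the old handler, `op.attrgetter("aval")`, could hand back a concrete aval, one that still carries a value, while `core.raise_to_shaped` reduces it to shape and dtype. That distinction matters once avals become cache keys, as in `_infer_params_cached` below. A toy model of the difference (illustrative classes, not JAX's):

```python
import numpy as np


class ShapedArray:
  """Toy stand-in for a shape/dtype-only abstract value."""
  def __init__(self, shape, dtype):
    self.shape, self.dtype = shape, dtype

  def __eq__(self, other):
    return (self.shape, self.dtype) == (other.shape, other.dtype)

  def __hash__(self):
    return hash((self.shape, self.dtype))


class ConcreteArray(ShapedArray):
  """Toy stand-in for an abstract value that also remembers its value."""
  def __init__(self, val):
    super().__init__(val.shape, str(val.dtype))
    self.val = val

  def __eq__(self, other):
    return (isinstance(other, ConcreteArray) and super().__eq__(other)
            and bool((self.val == other.val).all()))

  def __hash__(self):
    return hash((self.shape, self.dtype, self.val.tobytes()))


def raise_to_shaped(aval):
  # What the fixed handler achieves: drop the value, keep shape/dtype.
  return ShapedArray(aval.shape, aval.dtype)


a, b = ConcreteArray(np.array([1.0])), ConcreteArray(np.array([2.0]))
assert a != b                                    # value-specialized keys miss
assert raise_to_shaped(a) == raise_to_shaped(b)  # shape/dtype keys hit
```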
jax/_src/pjit.py (139 changes)
```diff
@@ -62,7 +62,9 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
+from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src import sharding
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
@@ -142,7 +144,6 @@ class PjitInfo(NamedTuple):
   In other words, this structure contains arguments to jit()/pjit(),
   preprocessed and validated.
   """
-  fun: Callable
   fun_sourceinfo: str | None
   fun_signature: inspect.Signature | None
   # Shardings, as specified by the user. These can either be UNSPECIFIED or they
@@ -168,11 +169,17 @@ class PjitInfo(NamedTuple):
   has_explicit_sharding: bool
   use_resource_env: bool  # False for jit, True for pjit
 
+  # Hash and compare PjitInfo by identity when used as a cache key.
+  def __hash__(self):
+    return id(self)
+
+  def __eq__(self, other):
+    return self is other
+
 
-def _python_pjit_helper(jit_info, *args, **kwargs):
-  p = _infer_params(jit_info, args, kwargs)
+def _python_pjit_helper(fun, jit_info, *args, **kwargs):
+  p, args_flat = _infer_params(fun, jit_info, args, kwargs)
 
-  args_flat = p.args_flat
   for arg in args_flat:
     dispatch.check_arg(arg)
 
@@ -185,7 +192,6 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
   except pxla.DeviceAssignmentMismatchError as e:
     fails, = e.args
     api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
-    fun = jit_info.fun
     fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
     msg = _device_assignment_mismatch_error(
         fun_name, fails, args_flat, api_name, p.arg_names)
@@ -304,6 +310,7 @@ def _read_pgle_profiler(jaxpr):
 def _cpp_pjit_evict_fn(self):
   self._clear_cache()
   _create_pjit_jaxpr.evict_function(self._fun)  # pytype: disable=attribute-error
+  _infer_params_cached.cache_clear()
 
 
 # The entries are doubled here from the default 4096 because _pjit_call_impl
@@ -319,12 +326,12 @@ def _get_cpp_global_cache(pjit_has_explicit_sharding):
   return _cpp_pjit_cache
 
 
-def _cpp_pjit(jit_info: PjitInfo):
+def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
 
   @api_boundary
   def cache_miss(*args, **kwargs):
     outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
-        jit_info, *args, **kwargs)
+        fun, jit_info, *args, **kwargs)
     executable = _read_most_recent_pjit_call_executable(jaxpr)
     pgle_profiler = _read_pgle_profiler(jaxpr)
     maybe_fastpath_data = _get_fastpath_data(
@@ -334,7 +341,6 @@ def _cpp_pjit(jit_info: PjitInfo):
 
     return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
 
-  fun = jit_info.fun
   cpp_pjit_f = xc._xla.pjit(
       fun_name(fun),
       fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
@@ -448,7 +454,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
       in_shardings, out_shardings, device, backend)
 
   return PjitInfo(
-      fun=fun,
       fun_sourceinfo=fun_sourceinfo,
       fun_signature=fun_signature,
       user_specified_in_shardings=user_specified_in_shardings,
@@ -469,7 +474,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
       use_resource_env=use_resource_env)
 
 
-def _make_jit_wrapper(jit_info: PjitInfo):
+def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
 
   @api_boundary
   def lower(*args, **kwargs):
@@ -478,7 +483,6 @@ def _make_jit_wrapper(jit_info: PjitInfo):
       return traced.lower()
     except pxla.DeviceAssignmentMismatchError as e:
       fails, = e.args
-      fun = jit_info.fun
       fun_name = getattr(fun, '__qualname__',
                          getattr(fun, '__name__', str(fun)))
       msg = _device_assignment_mismatch_error(
@@ -487,7 +491,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
 
   @api_boundary
   def eval_shape(*args, **kwargs):
-    p = _infer_params(jit_info, args, kwargs)
+    p, _ = _infer_params(fun, jit_info, args, kwargs)
     out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
     # TODO(yashkatariya): Add `Layout` to SDS.
     out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
@@ -496,16 +500,16 @@ def _make_jit_wrapper(jit_info: PjitInfo):
 
   @api_boundary
   def trace(*args, **kwargs) -> stages.Traced:
-    p = _infer_params(jit_info, args, kwargs)
+    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
     donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
     args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
-    lower_callable = partial(_resolve_and_lower, p.args_flat, **p.params,
+    lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
                              pgle_profiler=None)
     return stages.Traced(
         p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
-        lower_callable, p.args_flat, p.arg_names, p.num_consts)
+        lower_callable, args_flat, p.arg_names, p.num_consts)
 
-  wrapped = _cpp_pjit(jit_info)
+  wrapped = _cpp_pjit(fun, jit_info)
   wrapped.lower = lower
   wrapped.eval_shape = eval_shape
   wrapped.trace = trace
@@ -525,11 +529,11 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
       fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
       static_argnums, static_argnames, device, backend, abstracted_axes,
       keep_unused, inline, use_resource_env)
-  return _make_jit_wrapper(jit_info)
+  return _make_jit_wrapper(fun, jit_info)
 
 
 class PjitParams(NamedTuple):
-  args_flat: list[Any]
+  consts: list[Any]  # Only jaxpr constants, we can't keep other arguments alive
   params: dict[str, Any]
   in_avals: tuple[core.AbstractValue, ...]
   in_tree: PyTreeDef
@@ -540,34 +544,34 @@ class PjitParams(NamedTuple):
   attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
 
 
-def _infer_params(
-    ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
-) -> PjitParams:
+def _infer_params_impl(
+    fun: Callable,
+    ji: PjitInfo,
+    pjit_mesh: mesh_lib.Mesh | None,
+    resource_env: mesh_lib.ResourceEnv | None,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    in_avals: tuple[core.AbstractValue, ...] | None,
+) -> tuple[PjitParams, list[Any]]:
   have_kwargs = bool(kwargs)
   if have_kwargs and ji.user_specified_in_shardings:
     raise ValueError(
        "pjit does not support kwargs when in_shardings is specified.")
 
-  if ji.use_resource_env:
-    # We need to fetch the mesh from inside the wrapped function, because
-    # meshes are dynamically scoped (i.e., with a context manager).
-    resource_env = mesh_lib.thread_resources.env
-    pjit_mesh = resource_env.physical_mesh
+  if pjit_mesh is not None:
     jit_name = 'pjit'
     if (ji.backend or ji.device) and not pjit_mesh.empty:
       raise ValueError(
           "Mesh context manager should not be used with jit when backend or "
          "device is also specified as an argument to jit.")
   else:
-    resource_env = None
-    pjit_mesh = None
     jit_name = 'jit'
 
   axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
 
   dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
                    ji.static_argnums, ji.static_argnames)
-  f = lu.wrap_init(ji.fun)
+  f = lu.wrap_init(fun)
   f, res_paths = result_paths(f)
   f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
   del args
@@ -608,7 +612,7 @@ def _infer_params(
   if config.dynamic_shapes.value:
     in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
     in_avals = tuple(a for a, e in in_type if e)
-  else:
+  elif in_avals is None:
     avals = []
     for i, a in enumerate(explicit_args):
       try:
@@ -621,6 +625,8 @@ def _infer_params(
             f"computation, whose {arg_path}."
         ) from e
     in_type = in_avals = tuple(avals)
+  else:
+    in_type = in_avals
 
   in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
       in_shardings_treedef, in_shardings_leaves,
@@ -667,9 +673,78 @@ def _infer_params(
       keep_unused=ji.keep_unused,
       inline=ji.inline,
   )
-  return PjitParams(consts + args_flat, params, in_avals, in_tree, out_tree(),
+  return PjitParams(consts, params, in_avals, in_tree, out_tree(),
                     donated_invars, dbg.arg_names if dbg else None, len(consts),
-                    attrs_tracked)
+                    attrs_tracked), args_flat
+
+
+class InferParamsCacheEntry:
+  """Mutable value object for _infer_params_cached."""
+  __slots__ = ['pjit_params']
+
+  pjit_params: PjitParams | None
+
+  def __init__(self):
+    self.pjit_params = None
+
+
+# We use an outer cache that is keyed on the signature of the arguments, but
+# when populating a cache entry using _infer_params_impl, we need to provide
+# actual arguments. In principle we could refactor _infer_params_impl to look
+# only at an argument signature instead of args/kwargs in those cases that we
+# cache, but this was a more minimal change.
+@util.weakref_lru_cache
+def _infer_params_cached(
+    fun: Callable,
+    jit_info: PjitInfo,
+    signature: jax_jit.ArgumentSignature,
+    in_avals: tuple[core.AbstractValue, ...],
+    pjit_mesh: mesh_lib.Mesh | None,
+    resource_env: mesh_lib.ResourceEnv | None,
+) -> InferParamsCacheEntry:
+  return InferParamsCacheEntry()
+
+
+def _infer_params(
+    fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> tuple[PjitParams, list[Any]]:
+  if ji.use_resource_env:
+    # We need to fetch the mesh from inside the wrapped function, because
+    # meshes are dynamically scoped (i.e., with a context manager).
+    resource_env = mesh_lib.thread_resources.env
+    pjit_mesh = resource_env.physical_mesh
+  else:
+    resource_env = None
+    pjit_mesh = None
+
+  skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value
+  if not skip_cache:
+    signature, dynargs = jax_jit.parse_arguments(
+        args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
+        ji.static_argnames, tree_util.default_registry)
+    try:
+      avals = tuple(shaped_abstractify(a) for a in dynargs)
+    except (OverflowError, TypeError):
+      # If we see something we don't understand, use the slow path.
+      skip_cache = True
+
+  if skip_cache:
+    p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
+                                      kwargs, in_avals=None)
+    return p, p.consts + args_flat
+
+  entry = _infer_params_cached(
+      fun, ji, signature, avals, pjit_mesh, resource_env)
+  if entry.pjit_params is None:
+    p, args_flat = _infer_params_impl(
+        fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
+    if p.attrs_tracked:
+      # If there are attrs_tracked, don't use the cache.
+      return p, p.consts + args_flat
+    else:
+      entry.pjit_params = p
+  return entry.pjit_params, entry.pjit_params.consts + dynargs
 
 
 def _extract_implicit_args(
```
```diff
@@ -15,6 +15,7 @@
 # pyformat: disable
 from __future__ import annotations
 
+import collections
 from collections.abc import Generator, Iterable, Sequence
 from contextlib import ExitStack, contextmanager
 import datetime
@@ -315,6 +316,21 @@ def count_jit_tracing_cache_miss():
   finally:
     pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr
 
+@contextmanager
+def count_jit_infer_params_cache_miss():
+  original_infer_params_impl = pjit_lib._infer_params_impl
+  count = collections.defaultdict(int)
+
+  def infer_params_impl_and_count(fun, *args, **kw):
+    count[fun] += 1
+    return original_infer_params_impl(fun, *args, **kw)
+
+  pjit_lib._infer_params_impl = infer_params_impl_and_count
+  try:
+    yield count
+  finally:
+    pjit_lib._infer_params_impl = original_infer_params_impl
+
 
 @contextmanager
 def count_aot_jit_cpp_cache_miss():
```
```diff
@@ -53,6 +53,7 @@ from jax._src import linear_util as lu
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src import debugging
+from jax._src import pjit as pjit_lib
 from jax._src.ad_checkpoint import saved_residuals
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
@@ -2588,7 +2589,7 @@ class APITest(jtu.JaxTestCase):
 
   def test_eval_shape_trace_cache_share(self):
     def f(x):
-      return x * 2
+      return x
 
     inp = np.arange(8)
 
@@ -2596,8 +2597,32 @@ class APITest(jtu.JaxTestCase):
       jax.eval_shape(f, inp)
       jax.jit(f)(inp)
 
-    # one for `f` and another for mul (`x * 2`) which is jitted.
-    self.assertEqual(count[0], 2)
+    self.assertEqual(count[0], 1)
+
+  @unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31")
+  def test_jit_infer_params_cache(self):
+    def f(x):
+      return x
+
+    f_jit = jax.jit(f)
+
+    def g(x):
+      x = f_jit(x)  # noqa: F821
+      x = f_jit(x)  # noqa: F821
+      return x
+
+    g_jit = jax.jit(g)
+
+    inp = np.arange(8)
+    with jtu.count_jit_infer_params_cache_miss() as count:
+      g_jit(inp)
+
+    self.assertDictEqual(count, {f: 1, g: 1})
+    cache_size = pjit_lib._infer_params_cached.cache_info().currsize
+    del count, f, f_jit, g, g_jit
+    # Cache should only keep a weak reference to f and g.
+    self.assertLess(pjit_lib._infer_params_cached.cache_info().currsize,
+                    cache_size, msg=pjit_lib._infer_params_cached.cache_keys())
 
   def test_eval_shape_out_shardings(self):
     s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
```