Fix buggy and confusing logic in the C++/pjit caching path.

When we have a cache miss in `_cpp_pjit` we want to compile the function and
store the executable. Previously we had a roundabout way of getting hold of that
executable. We'd trace the function to a jaxpr but we wouldn't lower and compile
it ourselves. Instead, we'd call `pjit_p.bind`. The layers of the tracing onion
would be peeled off and eventually we'd hit the `pjit_p` impl rule,
`_pjit_call_impl`. This rule has its own cache. With luck we'd also miss *that*
cache, and then `_pjit_call_impl` would lower and compile the jaxpr and store
the executable in `most_recent_pjit_call_executable`. We'd eventually pop the
stack back up to the `_cpp_pjit` cache miss and then we'd get hold of the
compiled object by looking up `most_recent_pjit_call_executable`.

There's room for bugs here if we hit one cache but not the other. For example,
if we miss the `_cpp_pjit` cache but we hit the `_pjit_call_impl` cache then we
won't compile the executable. Normally that would just mean that the `_cpp_pjit`
cache won't be populated. But if we've previously hit a function with the same
jaxpr but slightly different compilation parameters (e.g. device IDs) then we'll
get a bogus hit in `most_recent_pjit_call_executable` and we'll add an incorrect
cache entry. The divergent cache behavior you need to trigger this started
happening with the "stackless" change because the tracing context became a
bigger part of the cache key and `_cpp_pjit` and `_pjit_call_impl` will in
general have different tracing contexts.

With this change, we remove the whole `most_recent_pjit_call_executable` system.
Instead `_cpp_pjit` lowers, compiles and runs the jaxpr itself and obtains the
executable directly rather than calling into `pjit_p.bind`. We do call into
`pjit_p.bind` if we're not in an eval context, but in that case we don't expect
to be able to populate the `_cpp_pjit` cache anyway.
This commit is contained in:
Dougal 2024-11-10 18:07:31 -05:00
parent a041ea152e
commit 763952a607
3 changed files with 45 additions and 39 deletions

View File

@ -892,6 +892,11 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
def check_eval_args(args):
  """Reject any `Tracer` found among `args`.

  Scans `args` in order and, on the first element that is a `Tracer`, raises
  the exception built by `escaped_tracer_error` for that element. Returns
  `None` when no element is a `Tracer`.
  """
  # Stop at the first offending argument, exactly like an early-exit loop.
  leaked = next((item for item in args if isinstance(item, Tracer)), None)
  if leaked is not None:
    raise escaped_tracer_error(leaked)
class EvalTrace(Trace):
def process_primitive(self, primitive, args, params):
@ -902,12 +907,11 @@ class EvalTrace(Trace):
else:
# TODO(dougalm): delete. this shouldn't be necessary
args = map(full_lower, args)
for arg in args:
if isinstance(arg, Tracer):
if config.data_dependent_tracing_fallback.value:
if config.data_dependent_tracing_fallback.value:
for arg in args:
if isinstance(arg, Tracer):
return primitive.bind_with_trace(arg._trace, args, params)
else:
raise escaped_tracer_error(arg)
check_eval_args(args)
return primitive.impl(*args, **params)
def process_call(self, primitive, f, tracers, params):

View File

@ -23,7 +23,6 @@ import logging
import operator as op
import weakref
from typing import NamedTuple, Any, Union, cast
import threading
import warnings
import numpy as np
@ -185,7 +184,16 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
args_flat = [*init_states, *args_flat]
try:
out_flat = pjit_p.bind(*args_flat, **p.params)
if (core.trace_state_clean() and
not config.debug_key_reuse.value and
not config.data_dependent_tracing_fallback.value):
args_flat = map(core.full_lower, args_flat)
core.check_eval_args(args_flat)
out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
else:
out_flat = pjit_p.bind(*args_flat, **p.params)
compiled = None
profiler = None
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
@ -215,7 +223,8 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
_set_states(p.attrs_tracked, final_states)
outs = tree_unflatten(p.out_tree, out_flat)
return outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], p.attrs_tracked
return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'],
p.attrs_tracked, compiled, profiler)
def _set_states(attrs_tracked, vals):
@ -286,21 +295,6 @@ def _get_fastpath_data(
return fastpath_data
class _MostRecentPjitCallExecutable(threading.local):
def __init__(self):
self.weak_key_dict = weakref.WeakKeyDictionary()
self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary()
_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()
def _read_most_recent_pjit_call_executable(jaxpr):
return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None)
def _read_pgle_profiler(jaxpr):
return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None)
def _cpp_pjit_evict_fn(self):
self._clear_cache()
_create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error
@ -335,10 +329,9 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
if config.no_tracing.value:
raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
"`jit`, but 'no_tracing' is set")
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
outs, out_flat, out_tree, args_flat, jaxpr, \
attrs_tracked, executable, pgle_profiler = _python_pjit_helper(
fun, jit_info, *args, **kwargs)
executable = _read_most_recent_pjit_call_executable(jaxpr)
pgle_profiler = _read_pgle_profiler(jaxpr)
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
jaxpr.consts, jit_info.abstracted_axes,
@ -1619,17 +1612,11 @@ def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
global _most_recent_pjit_call_executable
pgle_compile_options, pgle_profiler = {}, None
pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
if jaxpr not in pgle_profiler_dict:
pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler(
config.pgle_profiling_runs.value,
config.pgle_aggregation_percentile.value)
pgle_profiler = pgle_profiler_dict[jaxpr]
pgle_profiler = profiler.PGLEProfiler(
config.pgle_profiling_runs.value,
config.pgle_aggregation_percentile.value)
# The method below will return FDO profile when module was profiled
# config.jax_pgle_profiling_runs amount of times, otherwise the result will
# be None.
@ -1652,7 +1639,6 @@ def _pjit_call_impl_python(
compiler_options_kvs=compiler_options_kvs,
).compile()
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.enable_checks.value:
pxla.check_array_xla_sharding_layout_match(
@ -1674,7 +1660,7 @@ def _pjit_call_impl_python(
("abstract args", map(xla.abstractify, args)),
("fingerprint", fingerprint))
try:
return compiled.unsafe_call(*args), compiled
return compiled.unsafe_call(*args), compiled, pgle_profiler
except FloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case
@ -1720,13 +1706,12 @@ def _pjit_call_impl(*args, jaxpr,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
def call_impl_cache_miss(*args_, **kwargs_):
out_flat, compiled = _pjit_call_impl_python(
out_flat, compiled, pgle_profiler = _pjit_call_impl_python(
*args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline, compiler_options_kvs=compiler_options_kvs)
pgle_profiler = _read_pgle_profiler(jaxpr)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
jaxpr.consts, None, pgle_profiler)

View File

@ -1292,6 +1292,23 @@ class PJitTest(jtu.BufferDonationTestCase):
with self.assertRaisesRegex(ValueError, "spmd_axis_name"):
jax.vmap(f, spmd_axis_name='x')(xs)
def test_cache_bug(self):
  """Check that repeated `jnp.sin` calls on one input agree on output devices.

  The sequence of calls below is the one that produced a bad `_cpp_pjit`
  cache entry under the old `_most_recent_pjit_call_executable` system; the
  trailing comments record how each call hit or missed the two caches then.
  Skipped when fewer than two devices are available.
  """
  devices = list(jax.devices())
  if len(devices) < 2:
    raise unittest.SkipTest("Test requires 2 devices")

  def under_jvp(f):
    # Run `f` (a nullary function) under `jax.jvp` with no primals/tangents.
    return jax.jvp(f, (), ())

  x0 = jnp.zeros(1, device=devices[0])
  x1 = jnp.zeros(1, device=devices[1])

  # comments describe how caches worked under the old
  # `_most_recent_pjit_call_executable` system
  under_jvp(lambda: jnp.sin(x0))  # cpp_pjit miss, pjit_call_impl miss
  jnp.sin(x1)  # cpp_pjit miss, pjit_call_impl miss
  ans1 = jnp.sin(x0)  # cpp_pjit miss, pjit_call_impl hit. Bad cpp_pjit entry created
  ans2 = jnp.sin(x0)  # cpp_pjit hit with bad cache entry
  # Use the TestCase assertion rather than a bare `assert`: it is not stripped
  # under `python -O` and reports both device sets on failure.
  self.assertEqual(ans1.devices(), ans2.devices())
@jtu.pytest_mark_if_available('multiaccelerator')
class CustomPartitionerTest(jtu.JaxTestCase):