mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix buggy and confusing logic in the C++/pjit caching path.
When we have a cache miss in `_cpp_pjit` we want to compile the function and store the executable. Previously we had a roundabout way of getting hold of that executable. We'd trace the function to a jaxpr but we wouldn't lower and compile it ourselves. Instead, we'd call `pjit_p.bind`. The layers of the tracing onion would be peeled off and eventually we'd hit the `pjit_p` impl rule, `_pjit_call_impl`. This rule has its own cache. With luck we'd also miss *that* cache, and then `_pjit_call_impl` would lower and compile the jaxpr and store the executable in `most_recent_pjit_call_executable`. We'd eventually pop the stack back up to the `_cpp_pjit` cache miss and then we'd get hold of the compiled object by looking up `most_recent_pjit_call_executable`. There's room for bugs here if we hit one cache but not the other. For example, if we miss the `_cpp_pjit` cache but we hit the `_pjit_call_impl` cache then we won't compile the executable. Normally that would just mean that the `_cpp_pjit` cache won't be populated. But if we've previously hit a function with the same jaxpr but slightly different compilation parameters (e.g. device IDs) then we'll get a bogus hit in `most_recent_pjit_call_executable` and we'll add an incorrect cache entry. The divergent cache behavior you need to trigger this started happening with the "stackless" change because the tracing context became a bigger part of the cache key and `_cpp_pjit` and `_pjit_call_impl` will in general have different tracing contexts. With this change, we remove the whole `most_recent_pjit_call_executable` system. Instead `_cpp_pjit` lowers, compiles and runs the jaxpr itself and obtains the executable directly rather than calling into `pjit_p.bind`. We do call into `pjit_p.bind` if we're not in an eval context, but in that case we don't expect to be able to populate the `_cpp_pjit` cache anyway.
This commit is contained in:
parent
a041ea152e
commit
763952a607
@ -892,6 +892,11 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
aval_property = namedtuple("aval_property", ["fget"])
|
||||
aval_method = namedtuple("aval_method", ["fun"])
|
||||
|
||||
def check_eval_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, Tracer):
|
||||
raise escaped_tracer_error(arg)
|
||||
|
||||
class EvalTrace(Trace):
|
||||
|
||||
def process_primitive(self, primitive, args, params):
|
||||
@ -902,12 +907,11 @@ class EvalTrace(Trace):
|
||||
else:
|
||||
# TODO(dougalm): delete. this shouldn't be necessary
|
||||
args = map(full_lower, args)
|
||||
for arg in args:
|
||||
if isinstance(arg, Tracer):
|
||||
if config.data_dependent_tracing_fallback.value:
|
||||
if config.data_dependent_tracing_fallback.value:
|
||||
for arg in args:
|
||||
if isinstance(arg, Tracer):
|
||||
return primitive.bind_with_trace(arg._trace, args, params)
|
||||
else:
|
||||
raise escaped_tracer_error(arg)
|
||||
check_eval_args(args)
|
||||
return primitive.impl(*args, **params)
|
||||
|
||||
def process_call(self, primitive, f, tracers, params):
|
||||
|
@ -23,7 +23,6 @@ import logging
|
||||
import operator as op
|
||||
import weakref
|
||||
from typing import NamedTuple, Any, Union, cast
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -185,7 +184,16 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
|
||||
args_flat = [*init_states, *args_flat]
|
||||
|
||||
try:
|
||||
out_flat = pjit_p.bind(*args_flat, **p.params)
|
||||
if (core.trace_state_clean() and
|
||||
not config.debug_key_reuse.value and
|
||||
not config.data_dependent_tracing_fallback.value):
|
||||
args_flat = map(core.full_lower, args_flat)
|
||||
core.check_eval_args(args_flat)
|
||||
out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
|
||||
else:
|
||||
out_flat = pjit_p.bind(*args_flat, **p.params)
|
||||
compiled = None
|
||||
profiler = None
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
|
||||
@ -215,7 +223,8 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
|
||||
_set_states(p.attrs_tracked, final_states)
|
||||
|
||||
outs = tree_unflatten(p.out_tree, out_flat)
|
||||
return outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], p.attrs_tracked
|
||||
return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'],
|
||||
p.attrs_tracked, compiled, profiler)
|
||||
|
||||
|
||||
def _set_states(attrs_tracked, vals):
|
||||
@ -286,21 +295,6 @@ def _get_fastpath_data(
|
||||
return fastpath_data
|
||||
|
||||
|
||||
class _MostRecentPjitCallExecutable(threading.local):
|
||||
def __init__(self):
|
||||
self.weak_key_dict = weakref.WeakKeyDictionary()
|
||||
self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary()
|
||||
|
||||
_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()
|
||||
|
||||
|
||||
def _read_most_recent_pjit_call_executable(jaxpr):
|
||||
return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None)
|
||||
|
||||
|
||||
def _read_pgle_profiler(jaxpr):
|
||||
return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None)
|
||||
|
||||
def _cpp_pjit_evict_fn(self):
|
||||
self._clear_cache()
|
||||
_create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error
|
||||
@ -335,10 +329,9 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
|
||||
if config.no_tracing.value:
|
||||
raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
|
||||
"`jit`, but 'no_tracing' is set")
|
||||
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
|
||||
outs, out_flat, out_tree, args_flat, jaxpr, \
|
||||
attrs_tracked, executable, pgle_profiler = _python_pjit_helper(
|
||||
fun, jit_info, *args, **kwargs)
|
||||
executable = _read_most_recent_pjit_call_executable(jaxpr)
|
||||
pgle_profiler = _read_pgle_profiler(jaxpr)
|
||||
maybe_fastpath_data = _get_fastpath_data(
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
|
||||
jaxpr.consts, jit_info.abstracted_axes,
|
||||
@ -1619,17 +1612,11 @@ def _pjit_call_impl_python(
|
||||
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
global _most_recent_pjit_call_executable
|
||||
|
||||
pgle_compile_options, pgle_profiler = {}, None
|
||||
pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict
|
||||
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
|
||||
if jaxpr not in pgle_profiler_dict:
|
||||
pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler(
|
||||
config.pgle_profiling_runs.value,
|
||||
config.pgle_aggregation_percentile.value)
|
||||
|
||||
pgle_profiler = pgle_profiler_dict[jaxpr]
|
||||
pgle_profiler = profiler.PGLEProfiler(
|
||||
config.pgle_profiling_runs.value,
|
||||
config.pgle_aggregation_percentile.value)
|
||||
# The method below will return FDO profile when module was profiled
|
||||
# config.jax_pgle_profiling_runs amount of times, otherwise the result will
|
||||
# be None.
|
||||
@ -1652,7 +1639,6 @@ def _pjit_call_impl_python(
|
||||
compiler_options_kvs=compiler_options_kvs,
|
||||
).compile()
|
||||
|
||||
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
|
||||
# This check is expensive so only do it if enable_checks is on.
|
||||
if compiled._auto_spmd_lowering and config.enable_checks.value:
|
||||
pxla.check_array_xla_sharding_layout_match(
|
||||
@ -1674,7 +1660,7 @@ def _pjit_call_impl_python(
|
||||
("abstract args", map(xla.abstractify, args)),
|
||||
("fingerprint", fingerprint))
|
||||
try:
|
||||
return compiled.unsafe_call(*args), compiled
|
||||
return compiled.unsafe_call(*args), compiled, pgle_profiler
|
||||
except FloatingPointError as e:
|
||||
assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case
|
||||
|
||||
@ -1720,13 +1706,12 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
def call_impl_cache_miss(*args_, **kwargs_):
|
||||
out_flat, compiled = _pjit_call_impl_python(
|
||||
out_flat, compiled, pgle_profiler = _pjit_call_impl_python(
|
||||
*args, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts, resource_env=resource_env,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
inline=inline, compiler_options_kvs=compiler_options_kvs)
|
||||
pgle_profiler = _read_pgle_profiler(jaxpr)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
|
||||
jaxpr.consts, None, pgle_profiler)
|
||||
|
@ -1292,6 +1292,23 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "spmd_axis_name"):
|
||||
jax.vmap(f, spmd_axis_name='x')(xs)
|
||||
|
||||
def test_cache_bug(self):
|
||||
devices = list(jax.devices())
|
||||
if len(devices) < 2:
|
||||
raise unittest.SkipTest("Test requires 2 devices")
|
||||
|
||||
def under_jvp(f):
|
||||
return jax.jvp(f, (), ())
|
||||
|
||||
x0 = jnp.zeros(1, device=devices[0])
|
||||
x1 = jnp.zeros(1, device=devices[1])
|
||||
|
||||
# comments describe how caches worked under the old `_most_recent_pjit_call_executable` system
|
||||
under_jvp(lambda: jnp.sin(x0)) # cpp_pjit miss, pjit_call_impl miss
|
||||
jnp.sin(x1) # cpp_pjit miss, pjit_call_impl miss
|
||||
ans1 = jnp.sin(x0) # cpp_pjit miss, pjit_call_impl hit. Bad cpp_pjit entry created
|
||||
ans2 = jnp.sin(x0) # cpp_pjit hit with bad cache entry
|
||||
assert(ans1.devices() == ans2.devices())
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user