roll back to avoid weakref constraints

PiperOrigin-RevId: 513641366
This commit is contained in:
Matthew Johnson 2023-03-02 14:31:19 -08:00 committed by jax authors
parent 33c0a103c6
commit bd9c7bf81c
9 changed files with 88 additions and 144 deletions

View File

@ -401,15 +401,9 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
results.append((v.aval, 'from a constant'))
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, True, "saved_residuals")
arg_info = pe.arg_info_all(dbg)
for i, v in enumerate(jaxpr.invars):
if v in res_vars:
if arg_info is not None:
arg_name, arg_path = arg_info[i]
src = f'from the argument {arg_name}{arg_path.pprint("")}'
else:
src = 'from the argument at flattened index {i}'
src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
results.append((v.aval, src))
for eqn in jaxpr.eqns:

View File

@ -2984,12 +2984,10 @@ def make_jaxpr(fun: Callable,
in_type = tuple(zip(in_avals, keep_inputs))
f, out_tree = flatten_fun(f, in_tree)
f = lu.annotate(f, in_type)
debug_info = pe.debug_info(fun, in_tree, True, 'make_jaxpr')
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
f, debug_info=debug_info)
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
if return_shape:
out_avals, _ = unzip2(out_type)

View File

@ -391,7 +391,7 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun, in_tree, in_avals):
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, False, 'checkify')
debug = pe.debug_info(fun_, in_tree, False, 'checkify')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree()

View File

@ -32,7 +32,7 @@ from typing import (Any, Callable, ClassVar, DefaultDict, Dict, FrozenSet,
NamedTuple, Optional, Sequence, Set, Tuple, Type,
Union, cast)
import warnings
from weakref import ref, ReferenceType
from weakref import ref
import numpy as np
@ -45,11 +45,9 @@ from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.tree_util import PyTreeDef
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, HashableWrapper, weakref_lru_cache,
fun_sourceinfo)
HashableFunction, HashableWrapper, weakref_lru_cache)
import jax._src.pretty_printer as pp
from jax._src.lib import jax_jit
from jax._src import traceback_util
@ -68,43 +66,24 @@ Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
class DebugInfo(NamedTuple):
fn_ref: ReferenceType[Callable] # function being traced/staged
in_tree: Optional[PyTreeDef] # caller/constructor might not have this info
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
traced_for: str # "jit", "scan", "make_jaxpr", etc
# TODO(mattjj): add input type signature
@property
def fn(self):
return self.fn_ref()
@property
def func_src_info(self) -> Optional[str]:
return fun_sourceinfo(self.fn)
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects', '_debug_info']
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects']
_constvars: List[Var]
_invars: List[Var]
_outvars: List[Atom]
_eqns: List[JaxprEqn]
_effects: Effects
_debug_info: Optional[DebugInfo]
constvars = property(lambda self: self._constvars)
invars = property(lambda self: self._invars)
outvars = property(lambda self: self._outvars)
eqns = property(lambda self: self._eqns)
effects = property(lambda self: self._effects)
debug_info = property(lambda self: self._debug_info)
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
debug_info: Optional[DebugInfo] = None):
effects: Effects = no_effects):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -121,7 +100,6 @@ class Jaxpr:
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
def __str__(self):
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))

View File

@ -666,8 +666,7 @@ def lower_jaxpr_to_module(
donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
arg_names: Optional[Sequence[str]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@ -744,8 +743,7 @@ def lower_jaxpr_to_module(
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_shardings, result_shardings=result_shardings,
input_output_aliases=input_output_aliases,
arg_names=arg_names)
input_output_aliases=input_output_aliases)
if not ctx.module.operation.verify():
module_string = module_to_string(ctx.module)
@ -862,7 +860,6 @@ def lower_jaxpr_to_fun(
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
num_output_tokens: int = 0,
api_name: str = 'jit',
arg_names: Optional[Sequence[str]] = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -983,10 +980,6 @@ def lower_jaxpr_to_fun(
if alias is not None:
attrs["tf.aliasing_output"] = i32_attr(alias)
if config.jax_jit_pjit_api_merge and arg_names:
for attrs, name_ in zip(arg_attrs[num_dim_vars + num_tokens:], arg_names):
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])

View File

@ -3052,10 +3052,6 @@ def lower_sharding_computation(
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
arg_names = None if arg_info is None else [
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)
if i in kept_var_idx]
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -3068,8 +3064,7 @@ def lower_sharding_computation(
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_op_shardings,
result_shardings=out_op_shardings,
arg_names=arg_names)
result_shardings=out_op_shardings)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,

View File

@ -14,7 +14,6 @@
import functools
from functools import partial
import inspect
import itertools as it
import logging
import operator
@ -536,15 +535,3 @@ def use_cpp_method(is_enabled=True):
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
)
return decorator
def fun_sourceinfo(fun: Callable) -> Optional[str]:
while isinstance(fun, partial):
fun = fun.func
fun = inspect.unwrap(fun)
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
return f"{fun.__name__} at {filename}:{lineno}"
except AttributeError:
return None

View File

@ -41,13 +41,12 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
raise_to_shaped, Atom, JaxprEqn, Primitive,
ShapedArray, DShapedArray, mapped_aval,
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, DebugInfo)
InputType, OutputType, get_referent)
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
KeyPath, _generate_key_paths)
tree_leaves)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache,
fun_sourceinfo)
as_hashable_function, weakref_lru_cache)
map, unsafe_map = safe_map, map
@ -952,7 +951,7 @@ def tracers_to_jaxpr(
outvars = map(get_atom, out_tracers) # type: ignore[arg-type]
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type]
outvars, eqns, jaxpr_effects, None)
outvars, eqns, jaxpr_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@ -964,7 +963,7 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
effects=jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@ -977,7 +976,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
constvars, invars = split_list(jaxpr.invars, [n])
lifted_jaxpr = Jaxpr(constvars=tuple(constvars), invars=invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
effects=jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@ -1383,8 +1382,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects,
jaxpr.debug_info)
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects)
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
return new_jaxpr, used_inputs
@ -1488,22 +1486,12 @@ class DynamicJaxprTracer(core.Tracer):
if dbg is None:
return ""
origin = ("The error occurred while tracing the function "
f"{fun_sourceinfo(dbg.fn)} for {dbg.traced_for}. ")
arg_info = arg_info_all(dbg)
if invar_pos and arg_info:
arg_info = [arg_info[i] for i in invar_pos]
arg_names = [f'{name}{path.pprint("")}' for name, path in arg_info]
if len(arg_names) == 1:
arg_info_str = f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
arg_info_str = f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
arg_info_str = f"the arguments {', '.join(rest)}, and {last}"
origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}. ")
if invar_pos:
origin += ("This concrete value was not available in Python because it "
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {arg_info_str}.")
f"of {dbg.arg_info(invar_pos)}.")
elif progenitor_eqns:
msts = [" operation "
f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n"
@ -1573,8 +1561,7 @@ class JaxprStackFrame:
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects,
self.debug_info)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
@ -1587,7 +1574,7 @@ class JaxprStackFrame:
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns,
jaxpr_effects, self.debug_info)
jaxpr_effects)
# We can't run check_jaxpr until after we normalize.
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
@ -1651,7 +1638,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns,
jaxpr_effects, jaxpr.debug_info)
jaxpr_effects)
return new_jaxpr, new_constvals
ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
@ -1701,7 +1688,7 @@ def _inline_literals(jaxpr, constvals):
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
jaxpr_effects, jaxpr.debug_info)
jaxpr_effects)
return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
@ -1841,8 +1828,7 @@ class DynamicJaxprTrace(core.Trace):
with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals,
debug_info=debug_info_final(f, map_primitive.name))
f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name))
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -1976,36 +1962,65 @@ def _memoize(thunk):
return memoized
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
return DebugInfo(ref(fn), in_tree, has_kwargs, traced_for)
class DebugInfo(NamedTuple):
func_src_info: str
traced_for: str
arg_info: Callable[[int], str]
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
"Make a DebugInfo from data available to final-style primitives like pmap."
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
return DebugInfo(ref(fn.f), in_tree, has_kwargs, traced_for)
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
def arg_info(dbg: DebugInfo, flat_pos: int) -> Optional[Tuple[str, KeyPath]]:
infos = arg_info_all(dbg)
return None if infos is None else infos[flat_pos]
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
func_src_info = fun_sourceinfo(fn)
if in_tree is not None:
arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
else:
arg_info = arg_info_flattened # type: ignore
return DebugInfo(func_src_info, traced_for, arg_info)
def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
ba = None if dbg.in_tree is None else sig_info(dbg)
if ba is None: return None
return [(name, key_path) for name, dummy_arg in ba.arguments.items()
for key_path, _ in _generate_key_paths(dummy_arg)]
def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
if dbg.in_tree is None or dbg.fn is None: return None
def fun_sourceinfo(fun: Callable):
while isinstance(fun, functools.partial):
fun = fun.func
fun = inspect.unwrap(fun)
try:
dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves)
except:
return None
args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {})
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
return line_info
except AttributeError:
return "<unknown>"
def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
flat_pos: List[int]) -> str:
dummy_args = [False] * in_tree.num_leaves
for i in flat_pos: dummy_args[i] = True
if has_kwargs:
args, kwargs = tree_unflatten(in_tree, dummy_args)
else:
args, kwargs = tree_unflatten(in_tree, dummy_args), {}
try:
return inspect.signature(dbg.fn).bind(*args, **kwargs)
ba = inspect.signature(fn).bind(*args, **kwargs)
except (TypeError, ValueError):
return None
return arg_info_flattened(flat_pos)
arg_names = [f"'{name}'" for name, x in ba.arguments.items()
if any(tree_leaves(x))]
if len(arg_names) == 1:
return f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
return f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
return f"the arguments {', '.join(rest)}, and {last}"
def arg_info_flattened(flat_pos: List[int]) -> str:
if len(flat_pos) > 1:
return f"the argument passed at flattened positions {flat_pos}"
else:
return f"the argument passed at flattened position {flat_pos[0]}"
@profiler.annotate_function
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
@ -2237,7 +2252,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
out_type = tuple(zip(out_avals, kept_outs))
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
jaxpr.effects, jaxpr.debug_info)
jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
return new_jaxpr, out_type

View File

@ -1145,22 +1145,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIsNotNone(f.runtime_executable())
self.assertIsNotNone(g.runtime_executable())
def test_jit_lower_arg_info(self):
if not config.jax_array or not jax.config.jax_jit_pjit_api_merge:
raise unittest.SkipTest("test only applies after jit-pjit api merge")
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
mhlo_str = str(lowered.compiler_ir('mhlo'))
self.assertNotIn("\"x\"", mhlo_str)
self.assertIn("y['hi']", mhlo_str)
self.assertNotIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
self.assertIn("kwargs['z']", mhlo_str)
self.assertIn("kwargs['w']", mhlo_str)
def test_jit_enum_as_dict_keys_fails(self):
class E(enum.Enum):
A = 0
@ -3405,7 +3389,7 @@ class APITest(jtu.JaxTestCase):
else:
return 0
msg = r"on the value of the argument x"
msg = r"on the value of the argument 'x'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)
@ -3417,7 +3401,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments x and y"
msg = r"on the values of the arguments 'x' and 'y'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)
@ -3429,7 +3413,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments x and z"
msg = r"on the values of the arguments 'x' and 'z'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
@ -3448,7 +3432,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments args"
msg = r"on the values of the argument 'args'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
@ -3461,7 +3445,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments kwargs"
msg = r"on the values of the argument 'kwargs'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)
@ -3474,7 +3458,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the value of the argument xy"
msg = r"on the value of the argument 'xy'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)
@ -3507,7 +3491,7 @@ class APITest(jtu.JaxTestCase):
def g(x):
return f(x, True)
msg = r"on the value of the argument y"
msg = r"on the value of the argument 'y'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
@ -5200,11 +5184,11 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(res[0][0].shape, (1,))
self.assertEqual(res[0][1], "from a constant")
self.assertEqual(res[1][0].shape, ())
self.assertEqual(res[1][1], "from the argument x[0]")
self.assertEqual(res[1][1], "from the argument 'x'")
self.assertEqual(res[2][0].shape, ())
self.assertEqual(res[2][1], "from the argument x[1]")
self.assertEqual(res[2][1], "from the argument 'x'")
self.assertEqual(res[3][0].shape, ())
self.assertEqual(res[3][1], "from the argument y")
self.assertEqual(res[3][1], "from the argument 'y'")
self.assertEqual(res[4][0].shape, ())
self.assertStartsWith(res[4][1], "named 'z'")
self.assertEqual(res[5][0].shape, ())