mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
attach debug info to jaxpr, pass to mlir/mhlo
Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
parent
a002643a4a
commit
8440e27a5a
@ -401,9 +401,15 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
|
||||
results.append((v.aval, 'from a constant'))
|
||||
|
||||
assert len(jaxpr.invars) == len(in_leaves)
|
||||
dbg = pe.debug_info(f, in_tree, True, "saved_residuals")
|
||||
arg_info = pe.arg_info_all(dbg)
|
||||
for i, v in enumerate(jaxpr.invars):
|
||||
if v in res_vars:
|
||||
src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
|
||||
if arg_info is not None:
|
||||
arg_name, arg_path = arg_info[i]
|
||||
src = f'from the argument {arg_name}{arg_path.pprint("")}'
|
||||
else:
|
||||
src = 'from the argument at flattened index {i}'
|
||||
results.append((v.aval, src))
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
|
@ -2978,10 +2978,12 @@ def make_jaxpr(fun: Callable,
|
||||
in_type = tuple(zip(in_avals, keep_inputs))
|
||||
f, out_tree = flatten_fun(f, in_tree)
|
||||
f = lu.annotate(f, in_type)
|
||||
debug_info = pe.debug_info(fun, in_tree, True, 'make_jaxpr')
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
||||
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
|
||||
f, debug_info=debug_info)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
if return_shape:
|
||||
out_avals, _ = unzip2(out_type)
|
||||
|
@ -391,7 +391,7 @@ def initial_style_jaxpr(
|
||||
def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
|
||||
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun_, in_tree, False, 'checkify')
|
||||
debug = pe.debug_info(fun, in_tree, False, 'checkify')
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
return jaxpr, consts, out_tree()
|
||||
|
||||
|
@ -32,7 +32,7 @@ from typing import (Any, Callable, ClassVar, DefaultDict, Dict, FrozenSet,
|
||||
NamedTuple, Optional, Sequence, Set, Tuple, Type,
|
||||
Union, cast)
|
||||
import warnings
|
||||
from weakref import ref
|
||||
from weakref import ref, ReferenceType
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -45,9 +45,11 @@ from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
from jax._src import linear_util as lu
|
||||
|
||||
from jax._src import source_info_util
|
||||
from jax._src.tree_util import PyTreeDef
|
||||
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
|
||||
tuple_delete, as_hashable_function,
|
||||
HashableFunction, HashableWrapper, weakref_lru_cache)
|
||||
HashableFunction, HashableWrapper, weakref_lru_cache,
|
||||
fun_sourceinfo)
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src import traceback_util
|
||||
@ -66,24 +68,43 @@ Effects = effects.Effects
|
||||
EffectTypeSet = effects.EffectTypeSet
|
||||
no_effects: Effects = effects.no_effects
|
||||
|
||||
class DebugInfo(NamedTuple):
|
||||
fn_ref: ReferenceType[Callable] # function being traced/staged
|
||||
in_tree: Optional[PyTreeDef] # caller/constructor might not have this info
|
||||
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
|
||||
traced_for: str # "jit", "scan", "make_jaxpr", etc
|
||||
# TODO(mattjj): add input type signature
|
||||
|
||||
@property
|
||||
def fn(self):
|
||||
return self.fn_ref()
|
||||
|
||||
@property
|
||||
def func_src_info(self) -> Optional[str]:
|
||||
return fun_sourceinfo(self.fn)
|
||||
|
||||
|
||||
class Jaxpr:
|
||||
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects']
|
||||
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects', '_debug_info']
|
||||
|
||||
_constvars: List[Var]
|
||||
_invars: List[Var]
|
||||
_outvars: List[Atom]
|
||||
_eqns: List[JaxprEqn]
|
||||
_effects: Effects
|
||||
_debug_info: Optional[DebugInfo]
|
||||
|
||||
constvars = property(lambda self: self._constvars)
|
||||
invars = property(lambda self: self._invars)
|
||||
outvars = property(lambda self: self._outvars)
|
||||
eqns = property(lambda self: self._eqns)
|
||||
effects = property(lambda self: self._effects)
|
||||
debug_info = property(lambda self: self._debug_info)
|
||||
|
||||
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
|
||||
effects: Effects = no_effects):
|
||||
effects: Effects = no_effects,
|
||||
debug_info: Optional[DebugInfo] = None):
|
||||
"""
|
||||
Args:
|
||||
constvars: list of variables introduced for constants. Array constants are
|
||||
@ -100,6 +121,7 @@ class Jaxpr:
|
||||
self._outvars = list(outvars)
|
||||
self._eqns = list(eqns)
|
||||
self._effects = effects
|
||||
self._debug_info = debug_info
|
||||
|
||||
def __str__(self):
|
||||
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))
|
||||
|
@ -666,7 +666,8 @@ def lower_jaxpr_to_module(
|
||||
donated_args: Sequence[bool],
|
||||
replicated_args: Optional[Sequence[bool]] = None,
|
||||
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
|
||||
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
arg_names: Optional[Sequence[str]] = None,
|
||||
) -> LoweringResult:
|
||||
"""Lowers a top-level jaxpr to an MLIR module.
|
||||
|
||||
@ -743,7 +744,8 @@ def lower_jaxpr_to_module(
|
||||
num_output_tokens=0,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=arg_shardings, result_shardings=result_shardings,
|
||||
input_output_aliases=input_output_aliases)
|
||||
input_output_aliases=input_output_aliases,
|
||||
arg_names=arg_names)
|
||||
|
||||
if not ctx.module.operation.verify():
|
||||
module_string = module_to_string(ctx.module)
|
||||
@ -860,6 +862,7 @@ def lower_jaxpr_to_fun(
|
||||
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
|
||||
num_output_tokens: int = 0,
|
||||
api_name: str = 'jit',
|
||||
arg_names: Optional[Sequence[str]] = None,
|
||||
) -> func_dialect.FuncOp:
|
||||
"""Lowers jaxpr and its callees to an IR function.
|
||||
|
||||
@ -980,6 +983,10 @@ def lower_jaxpr_to_fun(
|
||||
if alias is not None:
|
||||
attrs["tf.aliasing_output"] = i32_attr(alias)
|
||||
|
||||
if config.jax_jit_pjit_api_merge and arg_names:
|
||||
for attrs, name_ in zip(arg_attrs[num_dim_vars + num_tokens:], arg_names):
|
||||
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
|
||||
|
||||
func_op.arg_attrs = ir.ArrayAttr.get(
|
||||
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
|
||||
|
||||
|
@ -3046,6 +3046,10 @@ def lower_sharding_computation(
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
|
||||
arg_names = None if arg_info is None else [
|
||||
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)
|
||||
if i in kept_var_idx]
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -3058,7 +3062,8 @@ def lower_sharding_computation(
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=in_op_shardings,
|
||||
result_shardings=out_op_shardings)
|
||||
result_shardings=out_op_shardings,
|
||||
arg_names=arg_names)
|
||||
|
||||
module, keepalive, host_callbacks = (
|
||||
lowering_result.module, lowering_result.keepalive,
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools as it
|
||||
import logging
|
||||
import operator
|
||||
@ -535,3 +536,15 @@ def use_cpp_method(is_enabled=True):
|
||||
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
def fun_sourceinfo(fun: Callable) -> Optional[str]:
|
||||
while isinstance(fun, partial):
|
||||
fun = fun.func
|
||||
fun = inspect.unwrap(fun)
|
||||
try:
|
||||
filename = fun.__code__.co_filename
|
||||
lineno = fun.__code__.co_firstlineno
|
||||
return f"{fun.__name__} at {filename}:{lineno}"
|
||||
except AttributeError:
|
||||
return None
|
||||
|
@ -41,12 +41,13 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
|
||||
raise_to_shaped, Atom, JaxprEqn, Primitive,
|
||||
ShapedArray, DShapedArray, mapped_aval,
|
||||
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent)
|
||||
InputType, OutputType, get_referent, DebugInfo)
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
|
||||
tree_leaves)
|
||||
KeyPath, _generate_key_paths)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, OrderedSet,
|
||||
as_hashable_function, weakref_lru_cache)
|
||||
as_hashable_function, weakref_lru_cache,
|
||||
fun_sourceinfo)
|
||||
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -951,7 +952,7 @@ def tracers_to_jaxpr(
|
||||
outvars = map(get_atom, out_tracers) # type: ignore[arg-type]
|
||||
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
|
||||
jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type]
|
||||
outvars, eqns, jaxpr_effects)
|
||||
outvars, eqns, jaxpr_effects, None)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
# del getvar # needed to avoid cyclic-reference closure, apparently!
|
||||
return jaxpr, const_vals, env_vals
|
||||
@ -963,7 +964,7 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
|
||||
lifted_jaxpr = Jaxpr(constvars=(),
|
||||
invars=jaxpr.constvars + jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
|
||||
effects=jaxpr.effects)
|
||||
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
|
||||
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
|
||||
return lifted_jaxpr
|
||||
|
||||
@ -976,7 +977,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
|
||||
constvars, invars = split_list(jaxpr.invars, [n])
|
||||
lifted_jaxpr = Jaxpr(constvars=tuple(constvars), invars=invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
|
||||
effects=jaxpr.effects)
|
||||
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
|
||||
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
|
||||
return lifted_jaxpr
|
||||
|
||||
@ -1382,7 +1383,8 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
|
||||
eqns = new_eqns[::-1]
|
||||
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
|
||||
|
||||
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects)
|
||||
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects,
|
||||
jaxpr.debug_info)
|
||||
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
|
||||
|
||||
return new_jaxpr, used_inputs
|
||||
@ -1486,12 +1488,22 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
if dbg is None:
|
||||
return ""
|
||||
|
||||
origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
|
||||
f"for {dbg.traced_for}. ")
|
||||
if invar_pos:
|
||||
origin = ("The error occurred while tracing the function "
|
||||
f"{fun_sourceinfo(dbg.fn)} for {dbg.traced_for}. ")
|
||||
arg_info = arg_info_all(dbg)
|
||||
if invar_pos and arg_info:
|
||||
arg_info = [arg_info[i] for i in invar_pos]
|
||||
arg_names = [f'{name}{path.pprint("")}' for name, path in arg_info]
|
||||
if len(arg_names) == 1:
|
||||
arg_info_str = f"the argument {arg_names[0]}"
|
||||
elif len(arg_names) == 2:
|
||||
arg_info_str = f"the arguments {arg_names[0]} and {arg_names[1]}"
|
||||
else:
|
||||
*rest, last = arg_names
|
||||
arg_info_str = f"the arguments {', '.join(rest)}, and {last}"
|
||||
origin += ("This concrete value was not available in Python because it "
|
||||
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
|
||||
f"of {dbg.arg_info(invar_pos)}.")
|
||||
f"of {arg_info_str}.")
|
||||
elif progenitor_eqns:
|
||||
msts = [" operation "
|
||||
f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n"
|
||||
@ -1561,7 +1573,8 @@ class JaxprStackFrame:
|
||||
constvars, constvals = unzip2(self.constvar_to_val.items())
|
||||
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
|
||||
self.eqns)
|
||||
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
|
||||
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects,
|
||||
self.debug_info)
|
||||
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
|
||||
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
||||
return jaxpr, constvals
|
||||
@ -1574,7 +1587,7 @@ class JaxprStackFrame:
|
||||
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars,
|
||||
self.eqns)
|
||||
jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns,
|
||||
jaxpr_effects)
|
||||
jaxpr_effects, self.debug_info)
|
||||
# We can't run check_jaxpr until after we normalize.
|
||||
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
|
||||
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
||||
@ -1638,7 +1651,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
|
||||
jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars,
|
||||
new_eqns)
|
||||
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns,
|
||||
jaxpr_effects)
|
||||
jaxpr_effects, jaxpr.debug_info)
|
||||
return new_jaxpr, new_constvals
|
||||
|
||||
ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
|
||||
@ -1688,7 +1701,7 @@ def _inline_literals(jaxpr, constvals):
|
||||
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
|
||||
new_eqns)
|
||||
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
|
||||
jaxpr_effects)
|
||||
jaxpr_effects, jaxpr.debug_info)
|
||||
return new_jaxpr, new_constvals
|
||||
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
@ -1828,7 +1841,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore
|
||||
with core.new_sublevel():
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name))
|
||||
f, self.main, reduced_in_avals,
|
||||
debug_info=debug_info_final(f, map_primitive.name))
|
||||
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
||||
if ordered_effects:
|
||||
raise ValueError("Ordered effects not supported for "
|
||||
@ -1962,65 +1976,36 @@ def _memoize(thunk):
|
||||
return memoized
|
||||
|
||||
|
||||
class DebugInfo(NamedTuple):
|
||||
func_src_info: str
|
||||
traced_for: str
|
||||
arg_info: Callable[[int], str]
|
||||
|
||||
|
||||
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
|
||||
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
|
||||
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
|
||||
|
||||
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
|
||||
traced_for: str) -> DebugInfo:
|
||||
func_src_info = fun_sourceinfo(fn)
|
||||
if in_tree is not None:
|
||||
arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
|
||||
else:
|
||||
arg_info = arg_info_flattened # type: ignore
|
||||
return DebugInfo(func_src_info, traced_for, arg_info)
|
||||
return DebugInfo(ref(fn), in_tree, has_kwargs, traced_for)
|
||||
|
||||
def fun_sourceinfo(fun: Callable):
|
||||
while isinstance(fun, functools.partial):
|
||||
fun = fun.func
|
||||
fun = inspect.unwrap(fun)
|
||||
try:
|
||||
filename = fun.__code__.co_filename
|
||||
lineno = fun.__code__.co_firstlineno
|
||||
line_info = f"{fun.__name__} at {filename}:{lineno}"
|
||||
return line_info
|
||||
except AttributeError:
|
||||
return "<unknown>"
|
||||
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
|
||||
"Make a DebugInfo from data available to final-style primitives like pmap."
|
||||
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
|
||||
return DebugInfo(ref(fn.f), in_tree, has_kwargs, traced_for)
|
||||
|
||||
def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
|
||||
flat_pos: List[int]) -> str:
|
||||
dummy_args = [False] * in_tree.num_leaves
|
||||
for i in flat_pos: dummy_args[i] = True
|
||||
if has_kwargs:
|
||||
args, kwargs = tree_unflatten(in_tree, dummy_args)
|
||||
else:
|
||||
args, kwargs = tree_unflatten(in_tree, dummy_args), {}
|
||||
def arg_info(dbg: DebugInfo, flat_pos: int) -> Optional[Tuple[str, KeyPath]]:
|
||||
infos = arg_info_all(dbg)
|
||||
return None if infos is None else infos[flat_pos]
|
||||
|
||||
def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
|
||||
ba = None if dbg.in_tree is None else sig_info(dbg)
|
||||
if ba is None: return None
|
||||
return [(name, key_path) for name, dummy_arg in ba.arguments.items()
|
||||
for key_path, _ in _generate_key_paths(dummy_arg)]
|
||||
|
||||
def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
|
||||
if dbg.in_tree is None or dbg.fn is None: return None
|
||||
try:
|
||||
ba = inspect.signature(fn).bind(*args, **kwargs)
|
||||
dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves)
|
||||
except:
|
||||
return None
|
||||
args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {})
|
||||
try:
|
||||
return inspect.signature(dbg.fn).bind(*args, **kwargs)
|
||||
except (TypeError, ValueError):
|
||||
return arg_info_flattened(flat_pos)
|
||||
arg_names = [f"'{name}'" for name, x in ba.arguments.items()
|
||||
if any(tree_leaves(x))]
|
||||
if len(arg_names) == 1:
|
||||
return f"the argument {arg_names[0]}"
|
||||
elif len(arg_names) == 2:
|
||||
return f"the arguments {arg_names[0]} and {arg_names[1]}"
|
||||
else:
|
||||
*rest, last = arg_names
|
||||
return f"the arguments {', '.join(rest)}, and {last}"
|
||||
|
||||
def arg_info_flattened(flat_pos: List[int]) -> str:
|
||||
if len(flat_pos) > 1:
|
||||
return f"the argument passed at flattened positions {flat_pos}"
|
||||
else:
|
||||
return f"the argument passed at flattened position {flat_pos[0]}"
|
||||
|
||||
return None
|
||||
|
||||
@profiler.annotate_function
|
||||
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
|
||||
@ -2252,7 +2237,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
|
||||
out_type = tuple(zip(out_avals, kept_outs))
|
||||
|
||||
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
|
||||
jaxpr.effects)
|
||||
jaxpr.effects, jaxpr.debug_info)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
return new_jaxpr, out_type
|
||||
|
||||
|
@ -1145,6 +1145,22 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsNotNone(f.runtime_executable())
|
||||
self.assertIsNotNone(g.runtime_executable())
|
||||
|
||||
def test_jit_lower_arg_info(self):
|
||||
if not config.jax_array or not jax.config.jax_jit_pjit_api_merge:
|
||||
raise unittest.SkipTest("test only applies after jit-pjit api merge")
|
||||
|
||||
def f(x, y, *args, **kwargs):
|
||||
return y['hi'] + args[1] + sum(kwargs.values())
|
||||
|
||||
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
|
||||
mhlo_str = str(lowered.compiler_ir('mhlo'))
|
||||
self.assertNotIn("\"x\"", mhlo_str)
|
||||
self.assertIn("y['hi']", mhlo_str)
|
||||
self.assertNotIn("args[0]", mhlo_str)
|
||||
self.assertIn("args[1]", mhlo_str)
|
||||
self.assertIn("kwargs['z']", mhlo_str)
|
||||
self.assertIn("kwargs['w']", mhlo_str)
|
||||
|
||||
def test_jit_enum_as_dict_keys_fails(self):
|
||||
class E(enum.Enum):
|
||||
A = 0
|
||||
@ -3389,7 +3405,7 @@ class APITest(jtu.JaxTestCase):
|
||||
else:
|
||||
return 0
|
||||
|
||||
msg = r"on the value of the argument 'x'"
|
||||
msg = r"on the value of the argument x"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
f(1)
|
||||
|
||||
@ -3401,7 +3417,7 @@ class APITest(jtu.JaxTestCase):
|
||||
else:
|
||||
return y
|
||||
|
||||
msg = r"on the values of the arguments 'x' and 'y'"
|
||||
msg = r"on the values of the arguments x and y"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
f(1, 2)
|
||||
|
||||
@ -3413,7 +3429,7 @@ class APITest(jtu.JaxTestCase):
|
||||
else:
|
||||
return y
|
||||
|
||||
msg = r"on the values of the arguments 'x' and 'z'"
|
||||
msg = r"on the values of the arguments x and z"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
f(1, 2, 3)
|
||||
|
||||
@ -3432,7 +3448,7 @@ class APITest(jtu.JaxTestCase):
|
||||
else:
|
||||
return y
|
||||
|
||||
msg = r"on the values of the argument 'args'"
|
||||
msg = r"on the values of the arguments args"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
f(1, 2, 3)
|
||||
|
||||
@ -3445,7 +3461,7 @@ class APITest(jtu.JaxTestCase):
|
||||
else:
|
||||
return y
|
||||
|
||||
msg = r"on the values of the argument 'kwargs'"
|
||||
msg = r"on the values of the arguments kwargs"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
f(x=1, y=2, z=3)
|
||||
|
||||
@ -3458,7 +3474,7 @@ class APITest(jtu.JaxTestCase):
|
||||
else:
|
||||
return y
|
||||
|
||||
msg = r"on the value of the argument 'xy'"
|
||||
msg = r"on the value of the argument xy"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
f((1, 2), z=3)
|
||||
|
||||
@ -3491,7 +3507,7 @@ class APITest(jtu.JaxTestCase):
|
||||
def g(x):
|
||||
return f(x, True)
|
||||
|
||||
msg = r"on the value of the argument 'y'"
|
||||
msg = r"on the value of the argument y"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
g(1)
|
||||
|
||||
@ -5184,11 +5200,11 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertEqual(res[0][0].shape, (1,))
|
||||
self.assertEqual(res[0][1], "from a constant")
|
||||
self.assertEqual(res[1][0].shape, ())
|
||||
self.assertEqual(res[1][1], "from the argument 'x'")
|
||||
self.assertEqual(res[1][1], "from the argument x[0]")
|
||||
self.assertEqual(res[2][0].shape, ())
|
||||
self.assertEqual(res[2][1], "from the argument 'x'")
|
||||
self.assertEqual(res[2][1], "from the argument x[1]")
|
||||
self.assertEqual(res[3][0].shape, ())
|
||||
self.assertEqual(res[3][1], "from the argument 'y'")
|
||||
self.assertEqual(res[3][1], "from the argument y")
|
||||
self.assertEqual(res[4][0].shape, ())
|
||||
self.assertStartsWith(res[4][1], "named 'z'")
|
||||
self.assertEqual(res[5][0].shape, ())
|
||||
|
Loading…
x
Reference in New Issue
Block a user