mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[better_errors] Add debug info to more Jaxprs and WrappedFun (step 1)
The plan is for all `core.Jaxpr` and `lu.WrappedFun` objects to carry non-None debug info. We change `lu.wrap_init` to construct the result-paths thunk whenever it is passed a `debug_info`, so that every `WrappedFun` has debug info with result-path support. We also change some internal calling conventions to stop passing a separate `debug_info` alongside a `WrappedFun` or a `Jaxpr` that already carries it. This yields several improvements in the presence of debug info, visible in debug_info_test.py.
This commit is contained in:
parent
7e353913f2
commit
d12aead696
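
For orientation, the calling-convention change in miniature (a sketch, not code from the diff; `fun`, `in_avals`, and `debug` are placeholders):

    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe

    # Before: debug info was threaded alongside the WrappedFun.
    flat_fun = lu.wrap_init(fun)
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)

    # After: the WrappedFun carries the debug info itself, so tracing
    # entry points take one fewer argument.
    flat_fun = lu.wrap_init(fun, debug_info=debug)
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)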
@@ -420,9 +420,9 @@ def _trace_to_jaxpr(fun: Callable,
                     in_avals: Sequence[core.AbstractValue],
                     debug: core.DebugInfo
                     ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
-  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
+  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
   try:
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
   except core.ConcretizationTypeError as e:
     msg, = e.args
     if 'for checkpoint' in msg:

@@ -699,7 +699,8 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
   assert next(ins_iter, None) is None
   with source_info_util.extend_name_stack('rematted_computation'):
     lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
-        lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
+        lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info),
+        in_pvals, False)

   # Transpose the linear jaxpr (which only has linear inputs).
   out_cts_iter = iter(out_cts_flat)

@@ -61,7 +61,7 @@ from jax._src.api_util import (
     flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     apply_flat_fun_nokwargs, check_callable, debug_info,
-    result_paths, flat_out_axes)
+    flat_out_axes)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc

@@ -1430,7 +1430,8 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
       "pmap", fun, args, kwargs,
       static_argnums=static_broadcasted_tuple)

-  f = lu.wrap_init(fun)
+  f = lu.wrap_init(fun, debug_info=dbg)
+  del dbg
   if static_broadcasted_tuple:
     if max(static_broadcasted_tuple) >= len(args):
       raise ValueError(

@@ -1477,9 +1478,6 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
       raise ValueError(msg) from None
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

-  f, res_paths = result_paths(f)
-  dbg = dbg.add_result_paths(res_paths)
-  f = lu.add_debug_info(f, dbg)
   f, out_axes_thunk = flat_out_axes(f, out_axes)
   flat_fun, out_tree = flatten_fun(f, in_tree)

@@ -590,7 +590,7 @@ def debug_info(
     static_argnums: tuple[int, ...] = (),
     static_argnames: tuple[str, ...] = (),
     result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
-    # TODO(necula): check if we really need this, e.g., to speed up tracing.
+    # TODO(necula): check if we really need this, e.g., to speed up tracing?
    sourceinfo: str | None = None,
    signature: inspect.Signature | None = None,
 ) -> core.DebugInfo:

@@ -674,29 +674,6 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
   arg_names = args_arg_names + kwargs_arg_names
   return arg_names

-@lu.transformation_with_aux2
-def result_paths(_fun, _store, *args, **kwargs):
-  "linear_util transform to get output pytree paths of pre-flattened function."
-  ans = _fun(*args, **kwargs)
-  _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
-  return ans
-
-# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
-def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
-                         debug: core.DebugInfo | None,
-                         result_paths: tuple[str, ...] | None = None,
-                         ) -> core.Jaxpr:
-  """Add debug info to jaxpr, given trace-time debug info and result paths."""
-  if debug is None:
-    return jaxpr
-  # TODO(necula): re-enable this safety check
-  # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
-  if result_paths is not None:
-    debug = debug._replace(result_paths=tuple(result_paths))
-  else:
-    debug = debug.resolve_result_paths()
-  return jaxpr.replace(debug_info=debug)
-
 def hoist_obj_attrs(f, flat_args):
   idxs, objs, flat_args_ = [], [], []
   for i, x in enumerate(flat_args):

@@ -721,7 +698,7 @@ def register_class_with_attrs(t: type) -> None:
 _class_with_attrs: set[type] = set()

 # TODO(mattjj): make this function faster
-def _check_no_aliased_ref_args(dbg, avals, args):
+def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
   assert config.mutable_array_checks.value
   refs: dict[int, int] = {}
   for i, (a, x) in enumerate(zip(avals, args)):

@@ -735,7 +712,7 @@ def _check_no_aliased_ref_args(dbg, avals, args):
          if dbg else
          f"at both flat index {dup_idx} and flat index {i}") from None

-def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
+def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
   assert config.mutable_array_checks.value
   refs: set[int] = {id(core.get_referent(c)) for c in consts
                     if isinstance(core.get_aval(c), AbstractRef)}

@@ -746,4 +723,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
         f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
         f"array reference of type {a.str_short()} was both closed over and "
         f"passed as the argument "
-        f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
+        f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")
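The `result_paths` transformation deleted above does not disappear: an equivalent helper (`_get_result_paths_thunk`) reappears in the `linear_util.py` hunks further down, where `wrap_init` installs it automatically. Its core is a `transformation_with_aux2` that records the output pytree paths as a side channel while the function runs (sketch mirroring the relocated code):

    from jax._src import linear_util as lu
    from jax._src.tree_util import keystr, generate_key_paths

    @lu.transformation_with_aux2
    def _result_paths(fun, store, *args, **kwargs):
      ans = fun(*args, **kwargs)
      # record one path string per output leaf, e.g. ("['a']", "['b'][0]")
      store.store(tuple(keystr(path) for path, _ in generate_key_paths(ans)))
      return ans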
@@ -1206,7 +1206,7 @@ def checkify(f: Callable[..., Out],
   fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
                                                      debug_info=debug),
                                         in_tree)
-  jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
+  jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ())
   jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
   # checkify:
   error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)

@@ -2369,7 +2369,8 @@ class CallPrimitive(Primitive):
   def get_bind_params(self, params):
     new_params = dict(params)
     jaxpr = new_params.pop('call_jaxpr')
-    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
+    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
+                                 jaxpr, ())
     if config.dynamic_shapes.value:
       subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
     return [subfun], new_params

@@ -2402,7 +2403,7 @@ class MapPrimitive(Primitive):
   map_primitive = True

   def bind_with_trace(self, trace, fun_and_args, params):
-    fun = fun_and_args[0]
+    fun: lu.WrappedFun = fun_and_args[0]
     args = fun_and_args[1:]
     assert len(params['in_axes']) == len(args)
     return trace.process_map(self, fun, args, params)

@@ -2412,8 +2413,9 @@ class MapPrimitive(Primitive):

   def get_bind_params(self, params):
     new_params = dict(params)
-    jaxpr = new_params.pop('call_jaxpr')
-    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
+    jaxpr: Jaxpr = new_params.pop('call_jaxpr')
+    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
+                                              debug_info=jaxpr.debug_info), jaxpr, ())
     axes = new_params.pop('out_axes')
     new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
     return [subfun], new_params

@@ -153,7 +153,7 @@ class custom_vmap:
         lu.wrap_init(self.fun, debug_info=debug),
         in_tree)
     in_avals = [core.get_aval(x) for x in args_flat]
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
     closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
     in_tree = treedef_tuple((tree_structure(consts), in_tree))
     assert self.vmap_rule is not None

@@ -147,11 +147,11 @@ class custom_dce:
       )
       static_args = [args[i] for i in self.static_argnums]
       dce_rule = api_util.prepend_static_args(
-          lu.wrap_init(self.dce_rule), static_args
+          lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args
       )
     else:
       fun = lu.wrap_init(self.fun, debug_info=debug)
-      dce_rule = lu.wrap_init(self.dce_rule)
+      dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule)
       dyn_args = args

     args_flat, in_tree = tree_util.tree_flatten(dyn_args)

@@ -176,7 +176,7 @@ class custom_dce:
       )
       assert self.dce_rule is not None
       dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
-          flat_rule, in_avals, debug_rule
+          flat_rule, in_avals
       )

       # This second round of DCE is used to work out which inputs are actually

@@ -191,7 +191,7 @@ class custom_dce:

       return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins

-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
     closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
     out_avals = closed_call.out_avals
     out_flat = custom_dce_p.bind(

@@ -366,7 +366,8 @@ def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_):
   # that most users of this API would compose this with a custom_jvp or
   # custom_vjp, which makes this less urgent.
   out = core.call_p.bind(
-      lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents
+      lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr),
+                   debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents
   )

   out_primals, out_tangents = util.split_list(out, [len(out_nz)])

@@ -485,13 +485,13 @@ class custom_partitioning:
       _check_for_tracers(static_args)
     else:
       static_args = []
-      f_, dyn_args = lu.wrap_init(self.fun), args
+      f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
     args_flat, in_tree = tree_util.tree_flatten(dyn_args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
     in_avals = [core.get_aval(x) for x in args_flat]
     mesh = mesh_lib.thread_resources.env.physical_mesh
     with core.extend_axis_env_nd(mesh.shape.items()):
-      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
     assert not len(consts)
     closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

@@ -98,7 +98,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
   nzs_out = tuple(type(t) is not Zero for t in out_tangents)
   out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
   out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)  # type: ignore[assignment]
-  jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
+  jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
   residual_avals = map(get_aval, consts)
   if attrs_tracked:
     raise NotImplementedError("TODO: attrs")

@@ -166,16 +166,17 @@ def _linearize_jaxpr(
   out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
   del lin_trace, ans, tracers, new_arg

+  debug_info = jaxpr.jaxpr.debug_info
   nzs_out = [type(t) is not Zero for t in out_tangents]
   out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t)
                        for (nz, t) in zip(nzs_out, out_tangents) if nz)
-  tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
+  tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
   tangent_trace.invalidate()
   if attrs_tracked:
     raise NotImplementedError("TODO: attrs")
   residuals_and_primals = (*tangent_consts, *out_primals)
   residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals)  # type: ignore[assignment]
-  primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
+  primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
   primal_trace.invalidate()
   num_residuals = len(tangent_consts)
   tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))

@@ -207,7 +208,7 @@ def direct_linearize(traceable: lu.WrappedFun,
   out_nzs = [type(t) is not Zero for t in out_tangents]
   out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
   out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents)  # type: ignore
-  jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents)
+  jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
   tangent_trace.invalidate()
   out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
                         pe.PartialVal.known(zeros_like_aval(t.aval))

@@ -1019,12 +1020,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
 def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
                nonzeros: Sequence[bool], instantiate: Sequence[bool]):
   assert len(jaxpr.in_avals) == len(nonzeros)
-  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
+  debug_info = jaxpr.jaxpr.debug_info
+  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
   f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
                                         nonzeros)
   tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
   avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
-  jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
+  jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
+      f_jvp, avals_in)
   return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

 @lu.transformation_with_aux2

@@ -760,7 +760,8 @@ def _batch_jaxpr2(
     axis_data,
     in_axes: tuple[int | NotMapped | RaggedAxis, ...],
 ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
-  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
+  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
+                   debug_info=closed_jaxpr.jaxpr.debug_info)
   f, out_axes = _batch_jaxpr_inner(f, axis_data)
   f = _batch_jaxpr_outer(f, axis_data, in_axes)
   in_axes2, avals_in = unzip2([

@@ -1393,7 +1393,6 @@ def lower_jaxpr_to_fun(
     MLIR func op
   """
   util.test_event("lower_jaxpr_to_fun", name)
-
   # The first dimension variable may be the platform index
   num_dim_vars = len(ctx.shape_poly_state.dim_vars)
   dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars

@@ -501,7 +501,7 @@ call_param_updaters[core.closed_call_p] = _closed_call_param_updater

 def abstract_eval_fun(fun, *avals, debug_info=None, **params):
   _, avals_out, _, () = trace_to_jaxpr_dynamic(
-      lu.wrap_init(fun, params), avals, debug_info)
+      lu.wrap_init(fun, params, debug_info=debug_info), avals)
   assert all(isinstance(aval, AbstractValue) for aval in avals_out)
   return avals_out

@@ -589,7 +589,7 @@ def trace_to_subjaxpr_nounits(

 @lu.transformation2
 def trace_to_subjaxpr_nounits2(
-    f,
+    f: Callable,
     tag: TraceTag,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):

@@ -950,7 +950,9 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
     return [*known_vals_out, *residuals]

   known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
-  jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
+  jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(
+      lu.wrap_init(fun, debug_info=f.debug_info),
+      known_avals)
   (out_unknowns, jaxpr_unknown, res_avals), = cell  # pytype: disable=bad-unpacking

   # check jaxpr_known and jaxpr_unknown in isolation

@@ -1124,7 +1126,7 @@ def _partial_eval_jaxpr_custom_cached(
   known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res,
                                      known_outvars, known_eqns)
   jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
-                      known_eqns, known_effects)
+                      known_eqns, known_effects, jaxpr.debug_info)
   config.enable_checks.value and core.check_jaxpr(jaxpr_known)

   _, ins_staged = partition_list(in_inst, jaxpr.invars)

@@ -1336,8 +1338,7 @@ def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
   dbg = jaxpr.debug_info and core.DebugInfo(
       jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
       jaxpr.debug_info.arg_names,
-      tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
-                              used_outputs) if b))
+      jaxpr.debug_info.filter_result_paths(used_outputs))
   new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
   config.enable_checks.value and core.check_jaxpr(new_jaxpr)
   return new_jaxpr

@@ -1424,10 +1425,8 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],

   dbg = jaxpr.debug_info and core.DebugInfo(
       jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
-      tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)),
-                              used_inputs) if b),
-      tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
-                              used_outputs) if b))
+      jaxpr.debug_info.filter_arg_names(used_inputs),
+      jaxpr.debug_info.filter_result_paths(used_outputs))
   new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
   config.enable_checks.value and core.check_jaxpr(new_jaxpr)

@@ -1644,7 +1643,9 @@ class JaxprStackFrame:
   def add_eqn(self, eqn: core.JaxprEqn):
     self.eqns.append(eqn)

-  def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer]
+  def to_jaxpr(self, trace: DynamicJaxprTrace,
+               out_tracers: Sequence[Tracer],
+               debug_info: core.DebugInfo | None,
                ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))

@@ -1657,7 +1658,8 @@ class JaxprStackFrame:
     outvars = state_outvars + explicit_outvars
     constvars, constvals = unzip2(self.constvar_to_val.items())
     jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
-    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
+    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects,
+                  debug_info)
     jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
     jaxpr, constvals = _inline_literals(jaxpr, constvals)  # type: ignore
     init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]

@@ -1950,7 +1952,7 @@ class DynamicJaxprTrace(core.Trace):
                        for a, in_axis in zip(in_avals, params['in_axes'])]
     with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
       jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
-          f, reduced_in_avals, f.debug_info)
+          f, reduced_in_avals)
     ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
     if ordered_effects:
       raise ValueError("Ordered effects not supported for "

@@ -2074,8 +2076,9 @@ class DynamicJaxprTrace(core.Trace):
     self.frame.add_eqn(eqn)
     return out_tracers

-  def to_jaxpr(self, out_tracers: Sequence[Tracer]):
-    return self.frame.to_jaxpr(self, out_tracers)
+  def to_jaxpr(self, out_tracers: Sequence[Tracer],
+               debug_info: core.DebugInfo | None):
+    return self.frame.to_jaxpr(self, out_tracers, debug_info)


 custom_staging_rules: dict[Primitive, Callable] = {}

@@ -2116,14 +2119,12 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
 def trace_to_jaxpr_dynamic(
     fun: lu.WrappedFun,
     in_avals: Sequence[AbstractValue],
-    debug_info: core.DebugInfo | None = None,
     *,
     keep_inputs: list[bool] | None = None,
 ) -> tuple[Jaxpr, list[AbstractValue], list[Any],
            list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
   keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
-
-  trace = DynamicJaxprTrace(debug_info)
+  trace = DynamicJaxprTrace(fun.debug_info)
   with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
     in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
     in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]

@@ -2131,8 +2132,8 @@ def trace_to_jaxpr_dynamic(
     ans = fun.call_wrapped(*in_tracers)

     out_tracers = map(trace.to_jaxpr_tracer, ans)
-    _check_no_returned_refs(debug_info, out_tracers)
-    jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
+    _check_no_returned_refs(fun.debug_info, out_tracers)
+    jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
     del trace, fun, in_tracers, out_tracers, ans

   config.enable_checks.value and core.check_jaxpr(jaxpr)

@@ -2160,7 +2161,7 @@ def _check_no_returned_refs(
         origin_info = ('\n\nThe returned mutable array was created on line '
                        f'{source_info_util.summarize(eqn.source_info)}.')
       elif v in frame.invars:
-        arg_name = dbg.arg_names[frame.invars.index(v)]  # type: ignore
+        arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)]  # type: ignore
         origin_info = ('\n\nThe returned mutable array was passed in as the '
                        f'argument {arg_name}.')
       else:

@@ -2172,10 +2173,10 @@ def _check_no_returned_refs(

 @profiler.annotate_function
 def trace_to_jaxpr_dynamic2(
-    fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None
+    fun: lu.WrappedFun,
 ) -> tuple[Jaxpr, OutputType, list[Any]]:

-  trace = DynamicJaxprTrace(debug_info)
+  trace = DynamicJaxprTrace(fun.debug_info)
   with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
     in_avals, keep_inputs = unzip2(fun.in_type)
     in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
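The net effect of the partial-eval changes above is that a freshly traced jaxpr carries its debug info from construction, rather than acquiring it in a later pass. A quick, hedged way to observe this (field names per `core.DebugInfo`; the exact printed values may vary by JAX version):

    import jax

    def my_f(x, y):
      return {"out": x + y}

    jaxpr = jax.make_jaxpr(my_f)(1.0, 2.0)
    dbg = jaxpr.jaxpr.debug_info          # attached at trace time
    print(dbg.traced_for, dbg.arg_names)  # e.g.: jit ('x', 'y')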
@@ -33,7 +33,6 @@ import numpy as np
 import jax

 from jax._src import api
-from jax._src import api_util
 from jax._src import compiler
 from jax._src import config
 from jax._src import core

@@ -652,7 +651,6 @@ class ParallelCallableInfo:
   in_axes: Iterable[int | None]
   out_axes_thunk: Callable[[], Sequence[int | None]]
   avals: Sequence[core.AbstractValue]
-  debug_info: core.DebugInfo | None

   @cached_property
   def local_devices(self):

@@ -723,8 +721,7 @@ def stage_parallel_callable(
       "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
     jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
-        fun, sharded_avals, pci.debug_info)
-    jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)
+        fun, sharded_avals)

   assert len(out_sharded_avals) == len(pci.out_axes), (
       len(out_sharded_avals), len(pci.out_axes))

@@ -758,7 +755,7 @@ def get_pmap_jaxpr(

   pci = ParallelCallableInfo(
       name, backend, axis_name, axis_size, global_axis_size, devices,
-      in_axes, out_axes_thunk, avals, fun.debug_info)
+      in_axes, out_axes_thunk, avals)
   with core.extend_axis_env_nd([(axis_name, axis_size)]):
     jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
   jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})

@@ -992,7 +989,7 @@ class UnloadedPmapExecutable:

     return PmapExecutable(
         self.compiled, self.build_execute_fun, fingerprint,
-        self.local_input_avals, self.jaxpr_debug_info, self)
+        self.local_input_avals, self)

   @staticmethod
   def from_hlo(hlo: ir.Module,

@@ -1119,24 +1116,23 @@ class UnloadedPmapExecutable:

 class PmapExecutable(stages.XlaExecutable):
   __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
-               "fingerprint", "in_avals", "_jaxpr_debug_info",
-               "_unloaded_executable"]
+               "fingerprint", "in_avals", "_unloaded_executable"]

   def __init__(self, xla_executable, build_unsafe_call, fingerprint,
-               in_avals, jaxpr_debug_info, unloaded_executable):
+               in_avals,
+               unloaded_executable: UnloadedPmapExecutable):
     self.xla_executable = xla_executable
     self._unsafe_call = None
     self.build_unsafe_call = build_unsafe_call
     self.fingerprint = fingerprint
     self.in_avals = in_avals
-    self._jaxpr_debug_info = jaxpr_debug_info
     self._unloaded_executable = unloaded_executable

   @property
   def unsafe_call(self) -> Callable[..., Any]:
     if self._unsafe_call is None:
       self._unsafe_call = self.build_unsafe_call()
-    return self._unsafe_call
+    return self._unsafe_call  # type: ignore

   # -- stages.XlaExecutable overrides

@@ -1147,7 +1143,8 @@ class PmapExecutable(stages.XlaExecutable):
   def call(self, *args):
     # TODO(frostig): do we need to check sharding and sharded avals?
     arg_avals = map(core.abstractify, args)
-    check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
+    check_arg_avals_for_call(self.in_avals, arg_avals,
+                             self._unloaded_executable.jaxpr_debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

@@ -3206,7 +3203,7 @@ def check_arg_avals_for_call(ref_avals, arg_avals,
                      f"but called with {len(arg_avals)}")

   if jaxpr_debug_info is not None:
-    arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
+    arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
   else:
     num_args = len(ref_avals)
     arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
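`safe_arg_names`, used in the `check_arg_avals_for_call` hunk above, is the defensive counterpart of the raw `arg_names` field: when the recorded names do not match the expected arity it degrades to placeholders instead of raising. A behavioral sketch of the accessor (mirroring the `(None,) * expected` fallback visible in the `DebugInfo` hunks below; the standalone function here is illustrative, not the actual method):

    def safe_arg_names(arg_names, expected):
      # Return the recorded names if they fit, else harmless placeholders.
      if len(arg_names) == expected:
        return tuple(arg_names)
      return (None,) * expected  # the source notes this "should not happen"

    assert safe_arg_names(("x", "y"), 2) == ("x", "y")
    assert safe_arg_names(("x",), 3) == (None, None, None)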
@@ -58,7 +58,7 @@ def _initial_style_open_jaxpr(fun: Callable,
       lu.wrap_init(fun, debug_info=debug_info),
       in_tree)
   jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
-      wrapped_fun, in_avals, debug_info)
+      wrapped_fun, in_avals)
   return jaxpr, consts, out_tree(), attrs_tracked

 @weakref_lru_cache

@@ -746,8 +746,9 @@ def _trace_composite_to_jaxpr(fun: Callable,
                               in_avals: Sequence[core.AbstractValue],
                               name: str,
                               debug_info: core.DebugInfo):
-  flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
-  jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
+  flat_fun, out_tree = api_util.flatten_fun_nokwargs(
+      lu.wrap_init(fun, debug_info=debug_info), in_tree)
+  jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
   if any(isinstance(c, core.Tracer) for c in consts):
     raise UnexpectedTracerError(
         "Found a JAX Tracer as a constant in the decomposition for the "

@@ -63,7 +63,7 @@ data must be immutable, because it will be stored in function memoization tables
 """
 from __future__ import annotations

-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from functools import partial
 from typing import Any, NamedTuple
 import weakref

@@ -71,6 +71,7 @@ import weakref
 from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
+from jax._src.tree_util import keystr, generate_key_paths
 from jax._src.util import curry, cache_clearing_funs, HashableFunction


@@ -275,13 +276,6 @@ class DebugInfo(NamedTuple):
   # e.g. ('[0]', '[1]', ...)
   result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None

-  def add_result_paths(self,
-                       result_paths_thunk: Callable[[], tuple[str, ...]]
-                       ) -> DebugInfo:
-    assert self.result_paths is None
-    return self._replace(result_paths=HashableFunction(result_paths_thunk,
-                                                       closure=()))
-
   def resolve_result_paths(self) -> DebugInfo:
     """Return a debug info with resolved result paths."""
     if callable(self.result_paths):

@@ -296,6 +290,10 @@ class DebugInfo(NamedTuple):
       # TODO(necula): this should not happen
       return (None,) * expected

+  def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]:
+    """Keep only the arg_names for which `keep` is True."""
+    return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b)
+
   def safe_result_paths(self, expected: int) -> tuple[str, ...]:
     """Get the result paths with a safety check."""
     assert not callable(self.result_paths), self

@@ -305,15 +303,34 @@ class DebugInfo(NamedTuple):
       # TODO(necula): this should not happen
       return ("",) * expected

+  def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
+    """Keep only the result_paths for which `keep` is True."""
+    assert not callable(self.result_paths), self
+    return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b)
+

 def wrap_init(f: Callable, params=None, *,
               debug_info: DebugInfo | None = None) -> WrappedFun:
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
   params_dict = {} if params is None else params
   params = () if params is None else tuple(sorted(params.items()))
-  return WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
+  fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, None)
+  if debug_info:
+    if debug_info.result_paths is None:
+      fun, result_paths_thunk = _get_result_paths_thunk(fun)
+      debug_info = debug_info._replace(
+          result_paths=HashableFunction(result_paths_thunk, closure=()))
+    fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
+                     fun.params, fun.in_type, debug_info)
+  return fun
+
+
+@transformation_with_aux2
+def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
+  ans = _fun(*args, **kwargs)
+  _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
+  return ans
+
 def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
   assert f.in_type is None
   if in_type is None:

@@ -350,16 +367,9 @@ def _check_input_type(in_type: core.InputType) -> None:
       provided[d.val] = True
   assert all(provided)

-def add_debug_info(f: WrappedFun, debug_info: DebugInfo | None
-                   ) -> WrappedFun:
-  """Produce a new WrappedFun with debug_info attached."""
-  assert f.debug_info is None
-  if debug_info is None:
-    return f
-  return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info)
-

-def cache(call: Callable, *, explain: Callable | None = None):
+def cache(call: Callable, *,
+          explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None):
   """Memoization decorator for functions taking a WrappedFun as first argument.

   Args:

@@ -367,6 +377,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
     underlying transforms and params on the WrappedFun are used as part of the
     memoization cache key.

+    explain: a function that is invoked upon cache misses to log an explanation
+      of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`.
+
   Returns:
     A memoized version of ``call``.
   """

@@ -382,7 +395,7 @@ def cache(call: Callable, *, explain: Callable | None = None):
       else:
         ans = call(fun, *args)
         if explain and config.explain_cache_misses.value:
-          explain(fun.f, cache is new_cache, cache, key)
+          explain(fun, cache is new_cache, cache, key)
         cache[key] = (ans, fun.stores)

       return ans
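Putting the new `wrap_init` together with `api_util.debug_info` (usage sketch; the argument spelling follows the calls visible elsewhere in this diff, and "example" is an arbitrary traced_for label):

    from jax._src import api_util
    from jax._src import linear_util as lu

    def f(x):
      return {"out": x * 2.0}

    dbg = api_util.debug_info("example", f, (1.0,), {})  # result_paths unresolved
    wf = lu.wrap_init(f, debug_info=dbg)  # wrap_init installs the result-paths thunk
    assert wf.debug_info is not None      # every WrappedFun now carries debug info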
@@ -423,7 +423,7 @@ class BlockSpec:
       )
     with tracing_grid_env(grid, mapped_dims):
       jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
-          flat_index_map_fun, index_map_avals, debug_info=debug
+          flat_index_map_fun, index_map_avals
       )
     mapped_block_shape = tuple(mapped if s is None else s for s in block_shape)
     if len(out_avals) != len(block_shape):

@@ -101,12 +101,12 @@ def _pallas_call_jvp_rule(
     primals,
     tangents,
     *,
-    jaxpr,
+    jaxpr: jax_core.Jaxpr,
     name_and_src_info,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping,
-    debug,
-    interpret,
+    debug: bool,
+    interpret: bool,
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],

@@ -1098,13 +1098,14 @@ def pallas_call_checkify_rule(error: checkify.Error,
   retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
                       *error_memref_aval, *output_aval, *scratch_aval]
   jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
-  wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
-      lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
+  debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
+                              retrace_in_avals, {})
+  wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
+      lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree)

   with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
     final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        wrapped_kernel_with_err, jaxpr_flat_avals, debug)
+        wrapped_kernel_with_err, jaxpr_flat_avals)

   # Prepare pallas_call inputs. We need to create new block specs
   # for the new error inputs and outputs.

@@ -1161,16 +1162,16 @@ def _trace_kernel_to_jaxpr(
     kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
     indexer: bool = False,
 ) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
+  fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
+  debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
   wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
-      lu.wrap_init(fun), kernel_in_tree)
+      lu.wrap_init(fun, debug_info=debug), kernel_in_tree)
   wrapped_kernel_fun = primitives.wrap_with_transforms(
       wrapped_kernel_fun, kernel_in_transforms
   )
-  fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
-  debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
   with grid_mapping.trace_env():
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
-                                                     kernel_avals, debug)
+                                                     kernel_avals)
   if consts:
     consts_avals = [jax_core.get_aval(c) for c in consts]
     if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):

@@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
 from jax._src.api_util import (
     argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
     donation_vector, check_callable, resolve_argnums,
-    argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
+    argnames_partial_except, debug_info,
     hoist_obj_attrs, _check_no_aliased_ref_args,
     _check_no_aliased_closed_over_refs)
 from jax._src.interpreters import partial_eval as pe

@@ -567,9 +567,7 @@ def _infer_params_impl(

   axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)

-  f = lu.wrap_init(fun)
-  f, res_paths = result_paths(f)
-  dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
+  f = lu.wrap_init(fun, debug_info=dbg)
   f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
   del args

@@ -618,7 +616,7 @@ def _infer_params_impl(
   in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
       in_shardings_treedef, in_shardings_leaves,
       ji.in_layouts_treedef, ji.in_layouts_leaves,
-      in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
+      in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs)

   attr_token = _attr_token(flat_fun, in_type)

@@ -627,8 +625,7 @@ def _infer_params_impl(
                    if mesh_lib.get_abstract_mesh().empty else mesh_lib.get_abstract_mesh())
   with mesh_lib.set_abstract_mesh(abstract_mesh):
     jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
-        flat_fun, in_type, attr_token, dbg,
-        HashableFunction(res_paths, closure=()),
+        flat_fun, in_type, attr_token,
         IgnoreKey(ji.inline))
   if config.mutable_array_checks.value:
     _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)

@@ -1171,17 +1168,18 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
 callsites: set[str] = set()

 def explain_tracing_cache_miss(
-    f: Callable, unseen_f: bool, cache: dict, key: tuple):
+    fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
   if config.check_tracer_leaks.value: return

   def unpack(key):
-    transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key
+    transforms, (), _, (in_type, _, inline), *_, ctx = key
     # TODO(dougalm,mattjj): enable cache miss explanation with attrs
     _, (_, (in_tree,)), *_ = transforms
-    return in_tree, in_type, debug_info, inline.val, ctx
-  in_tree, in_type, debug_info, inline, ctx = unpack(key)
+    return in_tree, in_type, inline.val, ctx
+  in_tree, in_type, inline, ctx = unpack(key)
   if inline: return

+  debug_info = fun.debug_info
   msg: list[str] = []
   p = msg.append
   done = lambda: logger.log(logging.WARNING, '\n'.join(msg))

@@ -1190,7 +1188,7 @@ def explain_tracing_cache_miss(
   p(f"TRACING CACHE MISS at {callsite} because:")

   # have we seen this function before at all?
-  fun_name = getattr(f, '__qualname__', f)
+  fun_name = getattr(fun.f, '__qualname__', fun.f)
   if debug_info is not None and debug_info.func_src_info:
     # TODO(necula): clean up the extraction of the source info
     _, *rest = debug_info.func_src_info.split(' at ')

@@ -1198,7 +1196,7 @@ def explain_tracing_cache_miss(
   else:
     src_info = ''
   if unseen_f:
-    p(f"  never seen function:\n    {fun_name} id={id(f)}{src_info}")
+    p(f"  never seen function:\n    {fun_name} id={id(fun.f)}{src_info}")
     if callsite in callsites:
       p("  but seen another function defined on the same line; maybe the function is\n"
         "  being re-defined repeatedly, preventing caching?")

@@ -1263,7 +1261,7 @@ def explain_tracing_cache_miss(
   # have we never seen these input types (eg shapes, dtypes) before?
   types_match = [k for k in trees_match if k[1] == in_type]
   if not types_match:
-    if len(in_type) < 5:
+    if len(in_type) < 5 and debug_info is not None:
       in_type_str = ':\n  {}'.format(', '.join(
           f'{n}: {ty.str_short(short_dtypes=True)}'
           for n, ty in zip(debug_info.arg_names, in_type)))

@@ -1275,7 +1273,12 @@ def explain_tracing_cache_miss(
     num_mismatch = sum(map(op.ne, closest_ty, in_type))
     p(f"  closest seen input type signature has {num_mismatch} mismatches, including:")
     add_weak_type_hint = False
-    for name, ty1, ty2 in zip(debug_info.arg_names, closest_ty, in_type):
+    if debug_info:
+      arg_names = debug_info.safe_arg_names(len(in_type))
+    else:
+      arg_names = (None,) * len(in_type)
+
+    for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
       if ty1 != ty2:
         if type(ty1) == type(ty2) == core.ShapedArray:
           s1, s2 = ty1.str_short(True), ty2.str_short(True)

@@ -1302,8 +1305,6 @@ def _create_pjit_jaxpr(
     fun: lu.WrappedFun,
     in_type: core.InputType | Sequence[core.AbstractValue],
     attr_data: int,
-    debug_info: core.DebugInfo,
-    result_paths: Callable,
     ignored_inline: IgnoreKey
 ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
            list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:

@@ -1317,17 +1318,13 @@ def _create_pjit_jaxpr(
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
     if config.dynamic_shapes.value:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
-          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
+          lu.annotate(fun, cast(core.InputType, in_type)))
       attrs_tracked = []
     else:
       jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
-          fun, in_type, debug_info=debug_info)
+          fun, in_type)
   # assert attr_data is sentinel or attr_data matches attrs_tracked

-  # TODO(dougalm,mattjj): enable debug info with attrs_tracked
-  if not config.dynamic_shapes.value and not attrs_tracked:
-    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())
-
   if config.debug_key_reuse.value:
     # Import here to avoid circular imports
     from jax.experimental.key_reuse._core import check_key_reuse_jaxpr

@@ -1928,7 +1925,9 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
 pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)


-def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
+def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
+                                    name: str, jaxpr: core.ClosedJaxpr,
+                                    effects, in_shardings,
                                     out_shardings, in_layouts, out_layouts,
                                     api_name):
   mod_ctx = ctx.module_context

@@ -1959,7 +1958,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
   return func


-def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
+def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
+                   jaxpr: core.ClosedJaxpr, in_shardings,
                    out_shardings, in_layouts, out_layouts, resource_env,
                    donated_invars, keep_unused, inline, compiler_options_kvs):
   effects = list(ctx.tokens_in.effects())

@@ -1987,8 +1987,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
 mlir.register_lowering(pjit_p, _pjit_lowering)


-def _pjit_batcher(axis_data, vals_in, dims_in,
-                  jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
+def _pjit_batcher(axis_data, vals_in,
+                  dims_in: tuple[int, ...],
+                  jaxpr: core.ClosedJaxpr,
+                  in_shardings, out_shardings, in_layouts, out_layouts,
                   resource_env, donated_invars, name, keep_unused, inline,
                   compiler_options_kvs):
   segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)

@@ -2037,7 +2039,8 @@ batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule

 def _pjit_batcher_for_sharding(
     s: Sharding | UnspecifiedValue,
-    dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
+    dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
+    ndim: int):
   if isinstance(s, UnspecifiedValue):
     return s
   hlo_s = s._to_xla_hlo_sharding(ndim)

@@ -2049,7 +2052,7 @@ def _pjit_batcher_for_sharding(
       return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
   new_op = hlo_s.to_proto().clone()
   tad = list(new_op.tile_assignment_dimensions)
-  tad.insert(dim, 1)
+  tad.insert(dim, 1)  # type: ignore
   new_op.tile_assignment_dimensions = tad
   new_gs = GSPMDSharding(
       s._device_assignment, new_op,

@@ -2171,8 +2174,9 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
 ad.primitive_linearizations[pjit_p] = _pjit_linearization


-def _pjit_partial_eval(trace, *in_tracers,
-                       jaxpr, in_shardings, out_shardings,
+def _pjit_partial_eval(trace: pe.JaxprTrace,
+                       *in_tracers,
+                       jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
                        in_layouts, out_layouts, resource_env, donated_invars,
                        name, keep_unused, inline, compiler_options_kvs):
   in_pvals = [t.pval for t in in_tracers]

@@ -2191,7 +2195,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   else:
     known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
         pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
-  unknown_outs = tuple(unknown_outs)
+  unknown_outs = tuple(unknown_outs)  # type: ignore[assignment]
   known_outs = tuple(not uk for uk in unknown_outs)
   num_residuals = len(res_avals)
   res_shardings = (UNSPECIFIED,) * num_residuals

@@ -2282,7 +2286,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
   unknown_out_avals = unknown_jaxpr.out_avals
   unknown_tracers_out = [
-      pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
+      pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)  # type: ignore
       for aval in unknown_out_avals
   ]
   eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
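The reworked `explain_tracing_cache_miss` above matches the new `lu.cache` contract documented earlier in this diff: the hook now receives the `WrappedFun` itself rather than the bare callable, so debug info no longer has to travel inside the cache key. A minimal sketch of a conforming hook (hypothetical, for illustration):

    def my_explain(fun, is_cache_first_use, cache, key):
      # `fun` is a lu.WrappedFun: the callable is fun.f, debug info is fun.debug_info.
      name = getattr(fun.f, "__qualname__", fun.f)
      dbg = fun.debug_info
      print(f"tracing cache miss for {name}",
            f"({dbg.func_src_info})" if dbg else "")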
@@ -998,9 +998,10 @@ def _initial_style_jaxpr(fun: Callable,
                          in_tree: api_util.PyTreeDef,
                          in_avals: Sequence[core.AbstractValue],
                          debug: core.DebugInfo):
-  fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
+  fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(
+      lu.wrap_init(fun, debug_info=debug),
       tree_util.treedef_tuple((in_tree,)))
-  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals)
   return jaxpr, consts, out_tree_thunk()

@@ -88,7 +88,7 @@ from jax._src.interpreters.partial_eval import (
 # TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes
 def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None):  # noqa
   jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic(
-      fun, in_avals, debug_info, keep_inputs=keep_inputs)
+      fun, in_avals, keep_inputs=keep_inputs)
   return jaxpr, out_avals, consts

@@ -14,9 +14,7 @@

 from __future__ import annotations

-import contextlib
-import functools
 import math
 import operator
 import re
 from typing import Any

@@ -74,12 +72,12 @@ def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
   # Strip the absolute path and the line number but check that it references
   # this file (to catch errors when the source info points in JAX internals)
   fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info)
-  res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}"
+  arg_names_str = ",".join([str(a) for a in dbg.arg_names])
+  res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={arg_names_str}"
   if isinstance(dbg.result_paths, tuple):
-    if dbg.result_paths:
-      res += f", result_paths={','.join(dbg.result_paths)}"
-    else:
-      res += ", result_paths=<empty>"
+    res += f", result_paths={','.join(dbg.result_paths)}"
+  elif dbg.result_paths is None:
+    res += ", result_paths=<empty>"
   return res

@@ -151,7 +149,8 @@ class DebugInfoTest(jtu.JaxTestCase):
     found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info)
                                 for j in all_jaxprs]

-    self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos)  # JAXPRS
+    self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos,
+                        "Jaxprs debug_infos")  # JAXPRS

     found_tracer_debug_infos = []
     if tracer_spy is not None:

@@ -173,7 +172,8 @@ class DebugInfoTest(jtu.JaxTestCase):
       else:
        found_tracer_debug_infos.append("None")

-    self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos)  # INSPECTED TRACERS
+    self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos,
+                        "Tracer debug_infos")  # INSPECTED TRACERS

     if not check_lowering: return
     # Collect all the lines in all the MLIR modules

@@ -186,36 +186,34 @@ class DebugInfoTest(jtu.JaxTestCase):
       mlir_modules_lines.extend(
           mlir.module_to_string(mod, enable_debug_info=True).split("\n"))

-    expected_and_found = set()
-    expected_and_not_found = set()
-    for exp in expected_lowering_lines:
-      for l in mlir_modules_lines:
-        ok = exp.match(l) if isinstance(exp, re.Pattern) else exp == l
-        if ok:
-          expected_and_found.add(exp)
-          break
-      else:
-        expected_and_not_found.add(exp)
-
-    if expected_and_not_found:
-      msg = "\n".join(mlir_modules_lines)
-      self.assertEmpty(expected_and_not_found, "\nNot found in the MLIR module lines:\n" + msg)
+    self._check_matches(expected_lowering_lines, mlir_modules_lines,
+                        "MLIR module lines", report_found_unexpected=False)

   def _check_matches(self,
                      expected: list[str | re.Pattern],
-                     found: list[str]):
-    expected_and_found = set()
-    unexpected: set[str] = set()
-    for debug_info in found:
-      for exp_re in expected:
-        ok = exp_re.match(debug_info) if isinstance(exp_re, re.Pattern) else exp_re == debug_info
-        if ok:
-          expected_and_found.add(exp_re)
-          break
-      else:
-        unexpected.add(debug_info)
-    self.assertEmpty(unexpected)  # found unexpected debug_info
-    self.assertEmpty([e for e in expected if e not in expected_and_found])  # expected element that was not found
+                     found: list[str],
+                     what: str,
+                     report_found_unexpected: bool = True):
+    expected_and_found: set[str | re.Pattern] = set()
+    found_and_expected: set[str] = set()
+    for exp_re in expected:
+      for found_line in found:
+        ok = exp_re.match(found_line) if isinstance(exp_re, re.Pattern) else exp_re == found_line
+        if ok:
+          expected_and_found.add(exp_re)
+          found_and_expected.add(found_line)
+
+    found_and_unexpected = set(found) - found_and_expected
+    all_found = "\n  ".join(found)
+    if report_found_unexpected and found_and_unexpected:
+      unexp_str = "\n  ".join(found_and_unexpected)
+      msg = f"Found unexpected {what}:\n  {unexp_str}\nAll found {what}:\n  {all_found}"
+      self.assertTrue(False, msg)
+
+    if expected_not_found := {e for e in expected if e not in expected_and_found}:
+      exp_str = "\n  ".join([str(e) for e in expected_not_found])
+      msg = f"Expected but not found in {what}:\n  {exp_str}\nAll found {what}:\n  {all_found}"
+      self.assertTrue(False, msg)

   def test_debug_info_basic(self):
     def my_f(x, y, z, w):

@@ -634,8 +632,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         3,
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
-            # TODO(necula): bad result names
-            'traced_for=jit, fun=my_f, arg_names=a, result_paths=<empty>',
+            'traced_for=jit, fun=my_f, arg_names=a, result_paths=',
            'traced_for=jit, fun=my_g, arg_names=b, result_paths=',
         ],
         check_tracer_arg_name=True,

@@ -694,7 +691,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']",
-            "None",  # TODO(necula)
+            "None",  # TODO(necula) missing debug info
         ],
         expected_lowering_lines=[
             re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"y\['hi'\]\"\)"),

@@ -780,7 +777,8 @@ class DebugInfoTest(jtu.JaxTestCase):
         lambda x, y, z: jax.jvp(jax.jit(f), (x, y, z), (x, y, z)),
         jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
         expected_jaxpr_debug_infos=[
-            "None",  # TODO(necula): missing debug info
+            # TODO(necula): arg_names, result_paths
+            "traced_for=jit, fun=f, arg_names=None,None,None,None, result_paths=,,,",
         ],
         tracer_spy=tracer_spy,
         expected_tracer_debug_infos=[

@@ -793,7 +791,7 @@ class DebugInfoTest(jtu.JaxTestCase):
             re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(unknown\)"),
             re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(unknown\)"),
             # TODO(necula): missing result names
-            re.compile(r".*func.func public @main\(.*-> \(tensor<f..>, tensor<f..>, tensor<f..>, tensor<f..>\) {"),
+            re.compile(r".*func.func public @main\(.*-> .*tensor<f..> {jax.result_info = \"\"}"),
         ])

   def test_vjp_of_jit(self):

@@ -805,6 +803,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])),
         jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
         expected_jaxpr_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=",
             "None",  # TODO(necula): missing debug info
         ],
         tracer_spy=tracer_spy,

@@ -816,8 +815,7 @@ class DebugInfoTest(jtu.JaxTestCase):
             # TODO(necula): missing arg_names
             re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(unknown\)"),
             re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(unknown\)"),
-            # TODO(necula): missing result names
-            re.compile(r".*func.func public @main\(.*-> tensor<f..> {"),
+            re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
         ])

   def test_vjp_of_nested_jit(self):

@@ -837,6 +835,8 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=[0],[1]",
+            # TODO(necula): result_paths
+            "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
             # TODO(necula): missing debug info
             "None",
         ],

@@ -872,8 +872,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=",
-            # TODO(necula): missing debug info
-            'None',
+            "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=['c']",
         ],
         expected_tracer_debug_infos=[
             # TODO(necula): missing debug info

@@ -930,8 +929,9 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=x, result_paths=",
-            # TODO(necula): some Jaxprs without debug info
-            "None"],
+            "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=",
+            "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=",
+        ],
         expected_tracer_debug_infos=[
             "traced_for=cond, fun=my_true_branch, arg_names=a,b",
             "traced_for=cond, fun=my_false_branch, arg_names=c,d"

@@ -957,8 +957,10 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=x, result_paths=",
-            # TODO(necula): some Jaxprs without debug info
-            "None"],
+            "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=",
+            "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=",
+            "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=",
+        ],
         expected_tracer_debug_infos=[
             "traced_for=switch, fun=my_branch0, arg_names=x0",
             "traced_for=switch, fun=my_branch1, arg_names=x1",

@@ -1031,6 +1033,8 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]",
+            # TODO(necula): bad result paths
+            "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,",
             'None',  # TODO(necula): some Jaxprs without debug info
         ],
         expected_tracer_debug_infos=[

@@ -1038,6 +1042,14 @@ class DebugInfoTest(jtu.JaxTestCase):
             "traced_for=scan, fun=f, arg_names=c,a",
             "traced_for=jit, fun=my_f, arg_names=x,as_",
             'None',  # TODO(necula): some missing debug info
         ],
+        expected_lowering_lines=[
+            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"c\"\)"),
+            re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"),
+            re.compile(r".*func.func public @main\(.* -> .*tensor<f..> {jax.result_info = \"\[0\]\""),
+            re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"\[1\]\""),
+            # TODO(necula): unnamed function?
+            re.compile(r".*func.func private @None"),
+        ])

   def test_while_loop(self):

@@ -1059,7 +1071,8 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=x, result_paths=",
-            'None',  # TODO(necula): some missing debug info
+            'traced_for=while_body, fun=my_body, arg_names=b, result_paths=',
+            'traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=',
         ],
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[

@@ -1080,7 +1093,9 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
             'None',  # TODO(necula): some missing debug info
+            # TODO(necula): bad arg_names, result_paths
+            'traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=[0][0],[0][1]',
         ],
         expected_tracer_debug_infos=[
             # TODO(necula): the arg_names are not right

@@ -1097,7 +1112,9 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=<lambda>, arg_names=ub,x, result_paths=",
             'None',  # TODO(necula): some missing debug info
+            re.compile(r'traced_for=while_cond, fun=_fori_cond_fun at .*/loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths='),
|
||||
# TODO(necula): arg_names and result_paths are not right
|
||||
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=[0],[1],[2]",
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
# TODO(necula): the arg_names are not right
|
||||
@ -1119,10 +1136,11 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
tracer_spy=tracer_spy,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=x, result_paths=[0],[1]",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None'],
|
||||
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=[0],[1]",
|
||||
],
|
||||
check_tracer_arg_name=True,
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp"
|
||||
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry"
|
||||
])
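
(Aside, illustrative only: the strings above match the DebugInfo that tracing
attaches to each Jaxpr. A minimal sketch, assuming this commit series' goal
that Jaxprs produced by make_jaxpr carry a non-None debug_info field; the
exact repr is internal and may differ across JAX versions:)

import jax
import jax.numpy as jnp

def my_scan_body(carry, inp):
  return carry + inp, carry * inp  # (new_carry, per-step output)

def my_f(x):
  return jax.lax.scan(my_scan_body, x, jnp.arange(3.))

# ClosedJaxpr.jaxpr carries the DebugInfo that the expectations above are
# matched against (traced-for tag, function name, flattened arg names).
closed = jax.make_jaxpr(my_f)(jnp.float32(0.))
print(closed.jaxpr.debug_info)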

def test_eval_shape(self):
@ -1179,6 +1197,17 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_tracer_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]",
],
expected_lowering_lines=[
# TODO(necula): we did not DCE y?
re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"),
re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"),
re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['u'\]\"\}"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['v'\]\"\}"),
]
)

def test_pmap_of_grad(self):
@ -1243,7 +1272,7 @@ class DebugInfoTest(jtu.JaxTestCase):
x, x_tan,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=[0],[1]',
"None", # TODO(necula): missing debug info
"traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
@ -1268,10 +1297,12 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
'None',
# TODO(necula): missing result_paths
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=my_g, arg_names=y",
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y"
])

def test_grad_remat(self):
@ -1360,8 +1391,8 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
'None',
"traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=[0],[1]",
],
expected_tracer_debug_infos=[
# TODO(necula): no leaked tracer from my_g_dce?
"traced_for=custom_dce, fun=my_g, arg_names=x",
@ -1388,8 +1419,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
'None',
# TODO(necula): bad arg_names (why None), bad result_paths
'traced_for=custom_dce, fun=my_f, arg_names=None,x, result_paths=[0],[1]',
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): no leaked tracer from my_rule?
@ -1427,7 +1459,14 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r"traced_for=jit, fun=_solve at .*scipy/linalg.py:.*, arg_names=a,b, result_paths="),
re.compile(r"traced_for=jit, fun=solve at .*/linalg.py:.*, arg_names=a,b, result_paths="),
re.compile(r"traced_for=jit, fun=_lu_solve at .*/linalg.py:.*, arg_names=lu,permutation,b, result_paths="),
"None", # TODO(necula): there is missing jaxpr debug info
# TODO(necula): why pointers to internal functions, arg_names, result_paths?
re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,x, result_paths='),
'traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=None,b, result_paths=',
'traced_for=custom_linear_solve solve, fun=my_solve, arg_names=None,x, result_paths=',
'traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=None,x, result_paths=',
],
expected_tracer_debug_infos=[
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x",
@ -1490,8 +1529,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): missing Jaxpr debug info
"None",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=[0],[1]",
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j",
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref",
@ -1519,7 +1559,10 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=input, result_paths=",
"None", # TODO(necula): missing tracer debug info
# TODO(necula): function source location points in JAX internals
# TODO(necula): arg_names and result_paths are wrong
re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*/pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="),
re.compile(r"traced_for=pallas_call index_map, fun=<lambda> at .*/pallas/core.py:.*, arg_names=, result_paths="),
],
expected_tracer_debug_infos=[
"traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref",
@ -1543,64 +1586,11 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_consts, arg_names=x, result_paths=",
"None",
"traced_for=composite, fun=my_consts, arg_names=x, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=composite, fun=my_consts, arg_names=x"])


class EagerPmapMixin:

  def setUp(self):
    super().setUp()
    stack = contextlib.ExitStack()
    stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
    stack.enter_context(jtu.ignore_warning(
        message="Some donated buffers were not usable", category=UserWarning))
    self.addCleanup(stack.close)


@jtu.pytest_mark_if_available('multiaccelerator')
class PythonPmapEagerTest(EagerPmapMixin, jtu.JaxTestCase):
  def test_pmap_lower_arg_info(self):
    def f(x, y, *args, **kwargs):
      return y['hi'] + args[1] + sum(kwargs.values())

    lowered = jax.pmap(f).lower(
        {'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
        jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
    hlo_str = lowered.as_text("stablehlo", debug_info=True)
    self.assertNotIn("\"x\"", hlo_str)
    self.assertIn("y['hi']", hlo_str)
    self.assertIn("args[0]", hlo_str)
    self.assertIn("args[1]", hlo_str)
    self.assertIn("kwargs['z']", hlo_str)
    self.assertIn("kwargs['w']", hlo_str)

  def test_pmap_lower_result_info(self):
    def f(x, y, z):
      return {'a': x, 'b': [y]}

    lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
                                [jnp.array([3])])
    hlo_str = lowered.as_text("stablehlo", debug_info=True)
    self.assertIn("jax.result_info = \"['a']\"", hlo_str)
    self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)

  def testLowerCompileArgTypeMismatch(self):
    f = jax.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
    shape = (jax.device_count(), 4)
    x = np.arange(math.prod(shape), dtype=int).reshape(shape)
    x_f32 = x.astype(jnp.float32)
    x_i32 = x.astype(jnp.int32)
    f_exe = f.lower(x_f32).compile()
    self.assertRaisesRegex(
        TypeError,
        r"Argument types differ .*"
        r"The mismatches are:\n"
        r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
        lambda: f_exe(x_i32))
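
(Aside, illustrative only: the mismatch test above relies on compiled
executables being specialized to the argument types they were lowered with;
a single-device jax.jit sketch of the same behavior:)

import jax
import jax.numpy as jnp

g = jax.jit(lambda x: x + 1)
exe = g.lower(jnp.float32(0.)).compile()
print(exe(jnp.float32(41.)))  # OK: dtype matches the compiled signature
# exe(jnp.int32(41)) would raise TypeError ("Argument types differ ...").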


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())