[better_errors] Add debug info to more Jaxprs and WrappedFun (step 1)

The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.

We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.

We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.

We obtain several improvements in presence of debug infos
in debug_info_test.py
This commit is contained in:
George Necula 2025-01-24 10:57:28 +02:00
parent 7e353913f2
commit d12aead696
22 changed files with 270 additions and 280 deletions

View File

@ -420,9 +420,9 @@ def _trace_to_jaxpr(fun: Callable,
in_avals: Sequence[core.AbstractValue],
debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
except core.ConcretizationTypeError as e:
msg, = e.args
if 'for checkpoint' in msg:
@ -699,7 +699,8 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
assert next(ins_iter, None) is None
with source_info_util.extend_name_stack('rematted_computation'):
lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info),
in_pvals, False)
# Transpose the linear jaxpr (which only has linear inputs).
out_cts_iter = iter(out_cts_flat)

View File

@ -61,7 +61,7 @@ from jax._src.api_util import (
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes)
flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@ -1430,7 +1430,8 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=dbg)
del dbg
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
raise ValueError(
@ -1477,9 +1478,6 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
raise ValueError(msg) from None
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
f, res_paths = result_paths(f)
dbg = dbg.add_result_paths(res_paths)
f = lu.add_debug_info(f, dbg)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)

View File

@ -590,7 +590,7 @@ def debug_info(
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing.
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> core.DebugInfo:
@ -674,29 +674,6 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
arg_names = args_arg_names + kwargs_arg_names
return arg_names
@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
debug: core.DebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if debug is None:
return jaxpr
# TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is not None:
debug = debug._replace(result_paths=tuple(result_paths))
else:
debug = debug.resolve_result_paths()
return jaxpr.replace(debug_info=debug)
def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):
@ -721,7 +698,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()
# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg, avals, args):
def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
@ -735,7 +712,7 @@ def _check_no_aliased_ref_args(dbg, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
@ -746,4 +723,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")

View File

@ -1206,7 +1206,7 @@ def checkify(f: Callable[..., Out],
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ())
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)

View File

@ -2369,7 +2369,8 @@ class CallPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
jaxpr, ())
if config.dynamic_shapes.value:
subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
@ -2402,7 +2403,7 @@ class MapPrimitive(Primitive):
map_primitive = True
def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
fun: lu.WrappedFun = fun_and_args[0]
args = fun_and_args[1:]
assert len(params['in_axes']) == len(args)
return trace.process_map(self, fun, args, params)
@ -2412,8 +2413,9 @@ class MapPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params

View File

@ -153,7 +153,7 @@ class custom_vmap:
lu.wrap_init(self.fun, debug_info=debug),
in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None

View File

@ -147,11 +147,11 @@ class custom_dce:
)
static_args = [args[i] for i in self.static_argnums]
dce_rule = api_util.prepend_static_args(
lu.wrap_init(self.dce_rule), static_args
lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args
)
else:
fun = lu.wrap_init(self.fun, debug_info=debug)
dce_rule = lu.wrap_init(self.dce_rule)
dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule)
dyn_args = args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
@ -176,7 +176,7 @@ class custom_dce:
)
assert self.dce_rule is not None
dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
flat_rule, in_avals, debug_rule
flat_rule, in_avals
)
# This second round of DCE is used to work out which inputs are actually
@ -191,7 +191,7 @@ class custom_dce:
return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
out_avals = closed_call.out_avals
out_flat = custom_dce_p.bind(
@ -366,7 +366,8 @@ def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_):
# that most users of this API would compose this with a custom_jvp or
# custom_vjp, which makes this less urgent.
out = core.call_p.bind(
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr),
debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents
)
out_primals, out_tangents = util.split_list(out, [len(out_nz)])

View File

@ -485,13 +485,13 @@ class custom_partitioning:
_check_for_tracers(static_args)
else:
static_args = []
f_, dyn_args = lu.wrap_init(self.fun), args
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

View File

@ -98,7 +98,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
@ -166,16 +166,17 @@ def _linearize_jaxpr(
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
del lin_trace, ans, tracers, new_arg
debug_info = jaxpr.jaxpr.debug_info
nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
tangent_trace.invalidate()
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment]
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
primal_trace.invalidate()
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
@ -207,7 +208,7 @@ def direct_linearize(traceable: lu.WrappedFun,
out_nzs = [type(t) is not Zero for t in out_tangents]
out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
pe.PartialVal.known(zeros_like_aval(t.aval))
@ -1019,12 +1020,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
debug_info = jaxpr.jaxpr.debug_info
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
@lu.transformation_with_aux2

View File

@ -760,7 +760,8 @@ def _batch_jaxpr2(
axis_data,
in_axes: tuple[int | NotMapped | RaggedAxis, ...],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
debug_info=closed_jaxpr.jaxpr.debug_info)
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
in_axes2, avals_in = unzip2([

View File

@ -1393,7 +1393,6 @@ def lower_jaxpr_to_fun(
MLIR func op
"""
util.test_event("lower_jaxpr_to_fun", name)
# The first dimension variable may be the platform index
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars

View File

@ -501,7 +501,7 @@ call_param_updaters[core.closed_call_p] = _closed_call_param_updater
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params), avals, debug_info)
lu.wrap_init(fun, params, debug_info=debug_info), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
@ -589,7 +589,7 @@ def trace_to_subjaxpr_nounits(
@lu.transformation2
def trace_to_subjaxpr_nounits2(
f,
f: Callable,
tag: TraceTag,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
@ -950,7 +950,9 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
return [*known_vals_out, *residuals]
known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, debug_info=f.debug_info),
known_avals)
(out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking
# check jaxpr_known and jaxpr_unknown in isolation
@ -1124,7 +1126,7 @@ def _partial_eval_jaxpr_custom_cached(
known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res,
known_outvars, known_eqns)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
known_eqns, known_effects)
known_eqns, known_effects, jaxpr.debug_info)
config.enable_checks.value and core.check_jaxpr(jaxpr_known)
_, ins_staged = partition_list(in_inst, jaxpr.invars)
@ -1336,8 +1338,7 @@ def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
jaxpr.debug_info.filter_result_paths(used_outputs))
new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr
@ -1424,10 +1425,8 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)),
used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
jaxpr.debug_info.filter_arg_names(used_inputs),
jaxpr.debug_info.filter_result_paths(used_outputs))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
@ -1644,7 +1643,9 @@ class JaxprStackFrame:
def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)
def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer]
def to_jaxpr(self, trace: DynamicJaxprTrace,
out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None,
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
@ -1657,7 +1658,8 @@ class JaxprStackFrame:
outvars = state_outvars + explicit_outvars
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects,
debug_info)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore
init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
@ -1950,7 +1952,7 @@ class DynamicJaxprTrace(core.Trace):
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
f, reduced_in_avals, f.debug_info)
f, reduced_in_avals)
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -2074,8 +2076,9 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return out_tracers
def to_jaxpr(self, out_tracers: Sequence[Tracer]):
return self.frame.to_jaxpr(self, out_tracers)
def to_jaxpr(self, out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None):
return self.frame.to_jaxpr(self, out_tracers, debug_info)
custom_staging_rules: dict[Primitive, Callable] = {}
@ -2116,14 +2119,12 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
def trace_to_jaxpr_dynamic(
fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: core.DebugInfo | None = None,
*,
keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
trace = DynamicJaxprTrace(debug_info)
trace = DynamicJaxprTrace(fun.debug_info)
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
@ -2131,8 +2132,8 @@ def trace_to_jaxpr_dynamic(
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.to_jaxpr_tracer, ans)
_check_no_returned_refs(debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
_check_no_returned_refs(fun.debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
del trace, fun, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
@ -2160,7 +2161,7 @@ def _check_no_returned_refs(
origin_info = ('\n\nThe returned mutable array was created on line '
f'{source_info_util.summarize(eqn.source_info)}.')
elif v in frame.invars:
arg_name = dbg.arg_names[frame.invars.index(v)] # type: ignore
arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)] # type: ignore
origin_info = ('\n\nThe returned mutable array was passed in as the '
f'argument {arg_name}.')
else:
@ -2172,10 +2173,10 @@ def _check_no_returned_refs(
@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None
fun: lu.WrappedFun,
) -> tuple[Jaxpr, OutputType, list[Any]]:
trace = DynamicJaxprTrace(debug_info)
trace = DynamicJaxprTrace(fun.debug_info)
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
in_avals, keep_inputs = unzip2(fun.in_type)
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)

View File

@ -33,7 +33,6 @@ import numpy as np
import jax
from jax._src import api
from jax._src import api_util
from jax._src import compiler
from jax._src import config
from jax._src import core
@ -652,7 +651,6 @@ class ParallelCallableInfo:
in_axes: Iterable[int | None]
out_axes_thunk: Callable[[], Sequence[int | None]]
avals: Sequence[core.AbstractValue]
debug_info: core.DebugInfo | None
@cached_property
def local_devices(self):
@ -723,8 +721,7 @@ def stage_parallel_callable(
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
fun, sharded_avals, pci.debug_info)
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)
fun, sharded_avals)
assert len(out_sharded_avals) == len(pci.out_axes), (
len(out_sharded_avals), len(pci.out_axes))
@ -758,7 +755,7 @@ def get_pmap_jaxpr(
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals, fun.debug_info)
in_axes, out_axes_thunk, avals)
with core.extend_axis_env_nd([(axis_name, axis_size)]):
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
@ -992,7 +989,7 @@ class UnloadedPmapExecutable:
return PmapExecutable(
self.compiled, self.build_execute_fun, fingerprint,
self.local_input_avals, self.jaxpr_debug_info, self)
self.local_input_avals, self)
@staticmethod
def from_hlo(hlo: ir.Module,
@ -1119,24 +1116,23 @@ class UnloadedPmapExecutable:
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
"fingerprint", "in_avals", "_jaxpr_debug_info",
"_unloaded_executable"]
"fingerprint", "in_avals", "_unloaded_executable"]
def __init__(self, xla_executable, build_unsafe_call, fingerprint,
in_avals, jaxpr_debug_info, unloaded_executable):
in_avals,
unloaded_executable: UnloadedPmapExecutable):
self.xla_executable = xla_executable
self._unsafe_call = None
self.build_unsafe_call = build_unsafe_call
self.fingerprint = fingerprint
self.in_avals = in_avals
self._jaxpr_debug_info = jaxpr_debug_info
self._unloaded_executable = unloaded_executable
@property
def unsafe_call(self) -> Callable[..., Any]:
if self._unsafe_call is None:
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call
return self._unsafe_call # type: ignore
# -- stages.XlaExecutable overrides
@ -1147,7 +1143,8 @@ class PmapExecutable(stages.XlaExecutable):
def call(self, *args):
# TODO(frostig): do we need to check sharding and sharded avals?
arg_avals = map(core.abstractify, args)
check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
check_arg_avals_for_call(self.in_avals, arg_avals,
self._unloaded_executable.jaxpr_debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable
@ -3206,7 +3203,7 @@ def check_arg_avals_for_call(ref_avals, arg_avals,
f"but called with {len(arg_avals)}")
if jaxpr_debug_info is not None:
arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
else:
num_args = len(ref_avals)
arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]

View File

@ -58,7 +58,7 @@ def _initial_style_open_jaxpr(fun: Callable,
lu.wrap_init(fun, debug_info=debug_info),
in_tree)
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
wrapped_fun, in_avals, debug_info)
wrapped_fun, in_avals)
return jaxpr, consts, out_tree(), attrs_tracked
@weakref_lru_cache

View File

@ -746,8 +746,9 @@ def _trace_composite_to_jaxpr(fun: Callable,
in_avals: Sequence[core.AbstractValue],
name: str,
debug_info: core.DebugInfo):
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug_info), in_tree)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError(
"Found a JAX Tracer as a constant in the decomposition for the "

View File

@ -63,7 +63,7 @@ data must be immutable, because it will be stored in function memoization tables
"""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, NamedTuple
import weakref
@ -71,6 +71,7 @@ import weakref
from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import keystr, generate_key_paths
from jax._src.util import curry, cache_clearing_funs, HashableFunction
@ -275,13 +276,6 @@ class DebugInfo(NamedTuple):
# e.g. ('[0]', '[1]', ...)
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
def add_result_paths(self,
result_paths_thunk: Callable[[], tuple[str, ...]]
) -> DebugInfo:
assert self.result_paths is None
return self._replace(result_paths=HashableFunction(result_paths_thunk,
closure=()))
def resolve_result_paths(self) -> DebugInfo:
"""Return a debug info with resolved result paths."""
if callable(self.result_paths):
@ -296,6 +290,10 @@ class DebugInfo(NamedTuple):
# TODO(necula): this should not happen
return (None,) * expected
def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]:
"""Keep only the arg_names for which `keep` is True."""
return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b)
def safe_result_paths(self, expected: int) -> tuple[str, ...]:
"""Get the result paths with a safety check."""
assert not callable(self.result_paths), self
@ -305,15 +303,34 @@ class DebugInfo(NamedTuple):
# TODO(necula): this should not happen
return ("",) * expected
def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
"""Keep only the result_paths for which `keep` is True."""
assert not callable(self.result_paths), self
return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b)
def wrap_init(f: Callable, params=None, *,
debug_info: DebugInfo | None = None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, None)
if debug_info:
if debug_info.result_paths is None:
fun, result_paths_thunk = _get_result_paths_thunk(fun)
debug_info = debug_info._replace(
result_paths=HashableFunction(result_paths_thunk, closure=()))
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
fun.params, fun.in_type, debug_info)
return fun
@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
assert f.in_type is None
if in_type is None:
@ -350,16 +367,9 @@ def _check_input_type(in_type: core.InputType) -> None:
provided[d.val] = True
assert all(provided)
def add_debug_info(f: WrappedFun, debug_info: DebugInfo | None
) -> WrappedFun:
"""Produce a new WrappedFun with debug_info attached."""
assert f.debug_info is None
if debug_info is None:
return f
return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info)
def cache(call: Callable, *, explain: Callable | None = None):
def cache(call: Callable, *,
explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None):
"""Memoization decorator for functions taking a WrappedFun as first argument.
Args:
@ -367,6 +377,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
underlying transforms and params on the WrappedFun are used as part of the
memoization cache key.
explain: a function that is invoked upon cache misses to log an explanation
of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`.
Returns:
A memoized version of ``call``.
"""
@ -382,7 +395,7 @@ def cache(call: Callable, *, explain: Callable | None = None):
else:
ans = call(fun, *args)
if explain and config.explain_cache_misses.value:
explain(fun.f, cache is new_cache, cache, key)
explain(fun, cache is new_cache, cache, key)
cache[key] = (ans, fun.stores)
return ans

View File

@ -423,7 +423,7 @@ class BlockSpec:
)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
flat_index_map_fun, index_map_avals, debug_info=debug
flat_index_map_fun, index_map_avals
)
mapped_block_shape = tuple(mapped if s is None else s for s in block_shape)
if len(out_avals) != len(block_shape):

View File

@ -101,12 +101,12 @@ def _pallas_call_jvp_rule(
primals,
tangents,
*,
jaxpr,
jaxpr: jax_core.Jaxpr,
name_and_src_info,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping,
debug,
interpret,
debug: bool,
interpret: bool,
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
@ -1098,13 +1098,14 @@ def pallas_call_checkify_rule(error: checkify.Error,
retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
*error_memref_aval, *output_aval, *scratch_aval]
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
retrace_in_avals, {})
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree)
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
wrapped_kernel_with_err, jaxpr_flat_avals, debug)
wrapped_kernel_with_err, jaxpr_flat_avals)
# Prepare pallas_call inputs. We need to create new block specs
# for the new error inputs and outputs.
@ -1161,16 +1162,16 @@ def _trace_kernel_to_jaxpr(
kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
indexer: bool = False,
) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), kernel_in_tree)
lu.wrap_init(fun, debug_info=debug), kernel_in_tree)
wrapped_kernel_fun = primitives.wrap_with_transforms(
wrapped_kernel_fun, kernel_in_transforms
)
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
kernel_avals)
if consts:
consts_avals = [jax_core.get_aval(c) for c in consts]
if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):

View File

@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
argnames_partial_except, debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
@ -567,9 +567,7 @@ def _infer_params_impl(
axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
f = lu.wrap_init(fun, debug_info=dbg)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
del args
@ -618,7 +616,7 @@ def _infer_params_impl(
in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
in_shardings_treedef, in_shardings_leaves,
ji.in_layouts_treedef, ji.in_layouts_leaves,
in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs)
attr_token = _attr_token(flat_fun, in_type)
@ -627,8 +625,7 @@ def _infer_params_impl(
if mesh_lib.get_abstract_mesh().empty else mesh_lib.get_abstract_mesh())
with mesh_lib.set_abstract_mesh(abstract_mesh):
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
flat_fun, in_type, attr_token,
IgnoreKey(ji.inline))
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
@ -1171,17 +1168,18 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
callsites: set[str] = set()
def explain_tracing_cache_miss(
f: Callable, unseen_f: bool, cache: dict, key: tuple):
fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
if config.check_tracer_leaks.value: return
def unpack(key):
transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key
transforms, (), _, (in_type, _, inline), *_, ctx = key
# TODO(dougalm,mattjj): enable cache miss explanation with attrs
_, (_, (in_tree,)), *_ = transforms
return in_tree, in_type, debug_info, inline.val, ctx
in_tree, in_type, debug_info, inline, ctx = unpack(key)
return in_tree, in_type, inline.val, ctx
in_tree, in_type, inline, ctx = unpack(key)
if inline: return
debug_info = fun.debug_info
msg: list[str] = []
p = msg.append
done = lambda: logger.log(logging.WARNING, '\n'.join(msg))
@ -1190,7 +1188,7 @@ def explain_tracing_cache_miss(
p(f"TRACING CACHE MISS at {callsite} because:")
# have we seen this function before at all?
fun_name = getattr(f, '__qualname__', f)
fun_name = getattr(fun.f, '__qualname__', fun.f)
if debug_info is not None and debug_info.func_src_info:
# TODO(necula): clean up the extraction of the source info
_, *rest = debug_info.func_src_info.split(' at ')
@ -1198,7 +1196,7 @@ def explain_tracing_cache_miss(
else:
src_info = ''
if unseen_f:
p(f" never seen function:\n {fun_name} id={id(f)}{src_info}")
p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}")
if callsite in callsites:
p(" but seen another function defined on the same line; maybe the function is\n"
" being re-defined repeatedly, preventing caching?")
@ -1263,7 +1261,7 @@ def explain_tracing_cache_miss(
# have we never seen these input types (eg shapes, dtypes) before?
types_match = [k for k in trees_match if k[1] == in_type]
if not types_match:
if len(in_type) < 5:
if len(in_type) < 5 and debug_info is not None:
in_type_str = ':\n {}'.format(', '.join(
f'{n}: {ty.str_short(short_dtypes=True)}'
for n, ty in zip(debug_info.arg_names, in_type)))
@ -1275,7 +1273,12 @@ def explain_tracing_cache_miss(
num_mismatch = sum(map(op.ne, closest_ty, in_type))
p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
add_weak_type_hint = False
for name, ty1, ty2 in zip(debug_info.arg_names, closest_ty, in_type):
if debug_info:
arg_names = debug_info.safe_arg_names(len(in_type))
else:
arg_names = (None,) * len(in_type)
for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
if ty1 != ty2:
if type(ty1) == type(ty2) == core.ShapedArray:
s1, s2 = ty1.str_short(True), ty2.str_short(True)
@ -1302,8 +1305,6 @@ def _create_pjit_jaxpr(
fun: lu.WrappedFun,
in_type: core.InputType | Sequence[core.AbstractValue],
attr_data: int,
debug_info: core.DebugInfo,
result_paths: Callable,
ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
@ -1317,17 +1318,13 @@ def _create_pjit_jaxpr(
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
lu.annotate(fun, cast(core.InputType, in_type)))
attrs_tracked = []
else:
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_type, debug_info=debug_info)
fun, in_type)
# assert attr_data is sentinel or attr_data matches attrs_tracked
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
if not config.dynamic_shapes.value and not attrs_tracked:
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
@ -1928,7 +1925,9 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
name: str, jaxpr: core.ClosedJaxpr,
effects, in_shardings,
out_shardings, in_layouts, out_layouts,
api_name):
mod_ctx = ctx.module_context
@ -1959,7 +1958,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
return func
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
jaxpr: core.ClosedJaxpr, in_shardings,
out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, keep_unused, inline, compiler_options_kvs):
effects = list(ctx.tokens_in.effects())
@ -1987,8 +1987,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
mlir.register_lowering(pjit_p, _pjit_lowering)
def _pjit_batcher(axis_data, vals_in, dims_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
def _pjit_batcher(axis_data, vals_in,
dims_in: tuple[int, ...],
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
@ -2037,7 +2039,8 @@ batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule
def _pjit_batcher_for_sharding(
s: Sharding | UnspecifiedValue,
dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
ndim: int):
if isinstance(s, UnspecifiedValue):
return s
hlo_s = s._to_xla_hlo_sharding(ndim)
@ -2049,7 +2052,7 @@ def _pjit_batcher_for_sharding(
return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
new_op = hlo_s.to_proto().clone()
tad = list(new_op.tile_assignment_dimensions)
tad.insert(dim, 1)
tad.insert(dim, 1) # type: ignore
new_op.tile_assignment_dimensions = tad
new_gs = GSPMDSharding(
s._device_assignment, new_op,
@ -2171,8 +2174,9 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
ad.primitive_linearizations[pjit_p] = _pjit_linearization
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
def _pjit_partial_eval(trace: pe.JaxprTrace,
*in_tracers,
jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars,
name, keep_unused, inline, compiler_options_kvs):
in_pvals = [t.pval for t in in_tracers]
@ -2191,7 +2195,7 @@ def _pjit_partial_eval(trace, *in_tracers,
else:
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
unknown_outs = tuple(unknown_outs)
unknown_outs = tuple(unknown_outs) # type: ignore[assignment]
known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(res_avals)
res_shardings = (UNSPECIFIED,) * num_residuals
@ -2282,7 +2286,7 @@ def _pjit_partial_eval(trace, *in_tracers,
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
unknown_out_avals = unknown_jaxpr.out_avals
unknown_tracers_out = [
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) # type: ignore
for aval in unknown_out_avals
]
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),

View File

@ -998,9 +998,10 @@ def _initial_style_jaxpr(fun: Callable,
in_tree: api_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: core.DebugInfo):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug),
tree_util.treedef_tuple((in_tree,)))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals)
return jaxpr, consts, out_tree_thunk()

View File

@ -88,7 +88,7 @@ from jax._src.interpreters.partial_eval import (
# TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes
def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None): # noqa
jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic(
fun, in_avals, debug_info, keep_inputs=keep_inputs)
fun, in_avals, keep_inputs=keep_inputs)
return jaxpr, out_avals, consts

View File

@ -14,9 +14,7 @@
from __future__ import annotations
import contextlib
import functools
import math
import operator
import re
from typing import Any
@ -74,12 +72,12 @@ def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
# Strip the absolute path and the line number but check that it references
# this file (to catch errors when the source info points in JAX internals)
fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info)
res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}"
arg_names_str = ",".join([str(a) for a in dbg.arg_names])
res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={arg_names_str}"
if isinstance(dbg.result_paths, tuple):
if dbg.result_paths:
res += f", result_paths={','.join(dbg.result_paths)}"
else:
res += ", result_paths=<empty>"
res += f", result_paths={','.join(dbg.result_paths)}"
elif dbg.result_paths is None:
res += ", result_paths=<empty>"
return res
@ -151,7 +149,8 @@ class DebugInfoTest(jtu.JaxTestCase):
found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info)
for j in all_jaxprs]
self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos) # JAXPRS
self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos,
"Jaxprs debug_infos") # JAXPRS
found_tracer_debug_infos = []
if tracer_spy is not None:
@ -173,7 +172,8 @@ class DebugInfoTest(jtu.JaxTestCase):
else:
found_tracer_debug_infos.append("None")
self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos) # INSPECTED TRACERS
self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos,
"Tracer debug_infos") # INSPECTED TRACERS
if not check_lowering: return
# Collect all the lines in all the MLIR modules
@ -186,36 +186,34 @@ class DebugInfoTest(jtu.JaxTestCase):
mlir_modules_lines.extend(
mlir.module_to_string(mod, enable_debug_info=True).split("\n"))
expected_and_found = set()
expected_and_not_found = set()
for exp in expected_lowering_lines:
for l in mlir_modules_lines:
ok = exp.match(l) if isinstance(exp, re.Pattern) else exp == l
if ok:
expected_and_found.add(exp)
break
else:
expected_and_not_found.add(exp)
if expected_and_not_found:
msg = "\n".join(mlir_modules_lines)
self.assertEmpty(expected_and_not_found, "\nNot found in the MLIR module lines:\n" + msg)
self._check_matches(expected_lowering_lines, mlir_modules_lines,
"MLIR module lines", report_found_unexpected=False)
def _check_matches(self,
expected: list[str | re.Pattern],
found: list[str]):
expected_and_found = set()
unexpected: set[str] = set()
for debug_info in found:
for exp_re in expected:
ok = exp_re.match(debug_info) if isinstance(exp_re, re.Pattern) else exp_re == debug_info
found: list[str],
what: str,
report_found_unexpected: bool = True):
expected_and_found: set[str | re.Pattern] = set()
found_and_expected: set[str] = set()
for exp_re in expected:
for found_line in found:
ok = exp_re.match(found_line) if isinstance(exp_re, re.Pattern) else exp_re == found_line
if ok:
expected_and_found.add(exp_re)
break
else:
unexpected.add(debug_info)
self.assertEmpty(unexpected) # found unexpected debug_info
self.assertEmpty([e for e in expected if e not in expected_and_found]) # expected element that was not found
found_and_expected.add(found_line)
found_and_unexpected = set(found) - found_and_expected
all_found = "\n ".join(found)
if report_found_unexpected and found_and_unexpected:
unexp_str = "\n ".join(found_and_unexpected)
msg = f"Found unexpected {what}:\n {unexp_str}\nAll found {what}:\n {all_found}"
self.assertTrue(False, msg)
if expected_not_found := {e for e in expected if e not in expected_and_found}:
exp_str = "\n ".join([str(e) for e in expected_not_found])
msg = f"Expected but not found in {what}:\n {exp_str}\nAll found {what}:\n {all_found}"
self.assertTrue(False, msg)
def test_debug_info_basic(self):
def my_f(x, y, z, w):
@ -634,8 +632,7 @@ class DebugInfoTest(jtu.JaxTestCase):
3,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): bad result names
'traced_for=jit, fun=my_f, arg_names=a, result_paths=<empty>',
'traced_for=jit, fun=my_f, arg_names=a, result_paths=',
'traced_for=jit, fun=my_g, arg_names=b, result_paths=',
],
check_tracer_arg_name=True,
@ -694,7 +691,7 @@ class DebugInfoTest(jtu.JaxTestCase):
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']",
"None", # TODO(necula)
"None", # TODO(necula) missing debug info
],
expected_lowering_lines=[
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"y\['hi'\]\"\)"),
@ -780,7 +777,8 @@ class DebugInfoTest(jtu.JaxTestCase):
lambda x, y, z: jax.jvp(jax.jit(f), (x, y, z), (x, y, z)),
jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
expected_jaxpr_debug_infos=[
"None", # TODO(necula): missing debug info
# TODO(necula): arg_names, result_paths
"traced_for=jit, fun=f, arg_names=None,None,None,None, result_paths=,,,",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
@ -793,7 +791,7 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(unknown\)"),
re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(unknown\)"),
# TODO(necula): missing result names
re.compile(r".*func.func public @main\(.*-> \(tensor<f..>, tensor<f..>, tensor<f..>, tensor<f..>\) {"),
re.compile(r".*func.func public @main\(.*-> .*tensor<f..> {jax.result_info = \"\"}"),
])
def test_vjp_of_jit(self):
@ -805,6 +803,7 @@ class DebugInfoTest(jtu.JaxTestCase):
lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])),
jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=",
"None", # TODO(necula): missing debug info
],
tracer_spy=tracer_spy,
@ -816,8 +815,7 @@ class DebugInfoTest(jtu.JaxTestCase):
# TODO(necula): missing arg_names
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(unknown\)"),
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(unknown\)"),
# TODO(necula): missing result names
re.compile(r".*func.func public @main\(.*-> tensor<f..> {"),
re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
])
def test_vjp_of_nested_jit(self):
@ -837,6 +835,8 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=[0],[1]",
# TODO(necula): result_paths
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
# TODO(necula): missing debug info
"None",
],
@ -872,8 +872,7 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=",
# TODO(necula): missing debug info
'None',
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=['c']",
],
expected_tracer_debug_infos=[
# TODO(necula): missing debug info
@ -930,8 +929,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
"None"],
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=",
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=cond, fun=my_true_branch, arg_names=a,b",
"traced_for=cond, fun=my_false_branch, arg_names=c,d"
@ -957,8 +957,10 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
"None"],
"traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=",
"traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=",
"traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=switch, fun=my_branch0, arg_names=x0",
"traced_for=switch, fun=my_branch1, arg_names=x1",
@ -1031,6 +1033,8 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]",
# TODO(necula): bad result paths
"traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,",
'None', # TODO(necula): some Jaxprs without debug info
],
expected_tracer_debug_infos=[
@ -1038,6 +1042,14 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=scan, fun=f, arg_names=c,a",
"traced_for=jit, fun=my_f, arg_names=x,as_",
'None', # TODO(necula): some missing debug info
],
expected_lowering_lines=[
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"c\"\)"),
re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"),
re.compile(r".*func.func public @main\(.* -> .*tensor<f..> {jax.result_info = \"\[0\]\""),
re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"\[1\]\""),
# TODO(necula): unnamed function?
re.compile(r".*func.func private @None"),
])
def test_while_loop(self):
@ -1059,7 +1071,8 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
'None', # TODO(necula): some missing debug info
'traced_for=while_body, fun=my_body, arg_names=b, result_paths=',
'traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=',
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
@ -1080,7 +1093,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
'None', # TODO(necula): some missing debug info
# TODO(necula): bad arg_names, result_paths
'traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=[0][0],[0][1]',
],
expected_tracer_debug_infos=[
# TODO(necula): the arg_names are not right
@ -1097,7 +1112,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=ub,x, result_paths=",
'None', # TODO(necula): some missing debug info
re.compile(r'traced_for=while_cond, fun=_fori_cond_fun at .*/loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths='),
# TODO(necula): arg_names and result_paths are not right
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=[0],[1],[2]",
],
expected_tracer_debug_infos=[
# TODO(necula): the arg_names are not right
@ -1119,10 +1136,11 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=[0],[1]",
# TODO(necula): some Jaxprs without debug info
'None'],
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=[0],[1]",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp"
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry"
])
def test_eval_shape(self):
@ -1179,6 +1197,17 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_tracer_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]",
],
expected_lowering_lines=[
# TODO(necula): we did not DCE y?
re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"),
re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"),
re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['u'\]\"\}"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['v'\]\"\}"),
]
)
def test_pmap_of_grad(self):
@ -1243,7 +1272,7 @@ class DebugInfoTest(jtu.JaxTestCase):
x, x_tan,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=[0],[1]',
"None", # TODO(necula): missing debug info
"traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
@ -1268,10 +1297,12 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
'None'],
# TODO(necula): missing result_paths
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=my_g, arg_names=y"
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y"
])
def test_grad_remat(self):
@ -1360,8 +1391,8 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
'None'],
"traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=[0],[1]",
],
expected_tracer_debug_infos=[
# TODO(necula): no leaked tracer from my_g_dce?
"traced_for=custom_dce, fun=my_g, arg_names=x",
@ -1388,8 +1419,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
'None'],
# TODO(necula): bad arg_names (why None), bad result_paths
'traced_for=custom_dce, fun=my_f, arg_names=None,x, result_paths=[0],[1]',
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): no leaked tracer from my_rule?
@ -1427,7 +1459,14 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r"traced_for=jit, fun=_solve at .*scipy/linalg.py:.*, arg_names=a,b, result_paths="),
re.compile(r"traced_for=jit, fun=solve at .*/linalg.py:.*, arg_names=a,b, result_paths="),
re.compile(r"traced_for=jit, fun=_lu_solve at .*/linalg.py:.*, arg_names=lu,permutation,b, result_paths="),
"None", # TODO(necula): there are missing jaxpr debug info
# TODO(necula): why pointers to internal functions, arg_names, result_paths?
re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,x, result_paths='),
'traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=None,b, result_paths=',
'traced_for=custom_linear_solve solve, fun=my_solve, arg_names=None,x, result_paths=',
'traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=None,x, result_paths=',
],
expected_tracer_debug_infos=[
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x",
@ -1490,8 +1529,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): missing Jaxpr debug info
"None"],
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=[0],[1]",
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j",
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref",
@ -1519,7 +1559,10 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=input, result_paths=",
"None", # TODO(necula): missing tracer debug info
# TODO(necula): function source location points in JAX internals
# TODO(necula): arg_names and result_paths are wrong
re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*/pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="),
re.compile(r"traced_for=pallas_call index_map, fun=<lambda> at .*/pallas/core.py:.*, arg_names=, result_paths="),
],
expected_tracer_debug_infos=[
"traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref",
@ -1543,64 +1586,11 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_consts, arg_names=x, result_paths=",
"None"
"traced_for=composite, fun=my_consts, arg_names=x, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=composite, fun=my_consts, arg_names=x"])
class EagerPmapMixin:
def setUp(self):
super().setUp()
stack = contextlib.ExitStack()
stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
stack.enter_context(jtu.ignore_warning(
message="Some donated buffers were not usable", category=UserWarning))
self.addCleanup(stack.close)
@jtu.pytest_mark_if_available('multiaccelerator')
class PythonPmapEagerTest(EagerPmapMixin, jtu.JaxTestCase):
def test_pmap_lower_arg_info(self):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.pmap(f).lower(
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
def test_pmap_lower_result_info(self):
def f(x, y, z):
return {'a': x, 'b': [y]}
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
[jnp.array([3])])
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def testLowerCompileArgTypeMismatch(self):
f = jax.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=int).reshape(shape)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
r"Argument types differ .*"
r"The mismatches are:\n"
r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
lambda: f_exe(x_i32))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())