[better_errors] Refactor more uses of pe.tracing_debug_info (part 3)

We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).

In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.

One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
This commit is contained in:
George Necula 2025-01-25 18:34:38 +02:00
parent d223dfc3f7
commit 32c98b9a76
11 changed files with 112 additions and 122 deletions

View File

@ -61,7 +61,7 @@ from jax._src.api_util import (
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
result_paths, flat_out_axes, debug_info_final)
result_paths, flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@ -1420,14 +1420,14 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
return global_axis_size
def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, in_devices, backend_name,
axis_size, args, kwargs):
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
dbg = tracing_debug_info(
'pmap', fun, args, kwargs,
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)
f = lu.wrap_init(fun)
@ -1478,9 +1478,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
f, res_paths = result_paths(f)
dbg = dbg.add_result_paths(res_paths)
f = lu.add_debug_info(f, dbg)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)
flat_fun = debug_info_final(flat_fun, dbg, res_paths)
is_explicit_global_axis_size = axis_size is not None
global_axis_size = _get_global_axis_size(local_axis_size, in_devices,

View File

@ -100,29 +100,6 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
ans = fun(*args)
return tree_unflatten(out_tree, ans)
def flattened_fun_in_tree(
fn: lu.WrappedFun
) -> tuple[PyTreeDef, Callable[[], PyTreeDef], bool] | None:
# This implementation relies on internal details of linear_util.py's
# WrappedFun, but it's for the worthy cause of better user error messages.
# It can fail (i.e. return None) if its WrappedFun argument is not transformed
# with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
# core.eval_jaxpr encounters a call primitive (though at that point we're just
# round-tripping jaxprs and the user errors in question are impossible).
assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
assert (isinstance(flatten_fun_nokwargs, partial) and
len(flatten_fun_nokwargs.args) == 1)
flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
try:
((in_tree,), out_tree_store, has_kwargs), = (
(args, store, f is flatten_fun.args[0])
for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
except ValueError:
# When `fn` is not the result of flatten_fun or flatten_fun_nokwargs
return None
else:
return in_tree, lambda: out_tree_store.val, has_kwargs # type: ignore[union-attr]
@lu.transformation_with_aux2
def flatten_fun_nokwargs2(f, store, in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
@ -705,6 +682,7 @@ def result_paths(_fun, _store, *args, **kwargs):
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
@ -712,7 +690,8 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
return jaxpr
assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
# TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None:
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
@ -720,16 +699,6 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)
def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
res_paths_thunk: Callable[[], tuple[str, ...]]
) -> lu.WrappedFun:
"Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
if dbg is None: return f
assert dbg.result_paths_thunk is None
res_paths_thunk_ = HashableFunction(res_paths_thunk, closure=())
return lu.add_debug_info(f, dbg._replace(result_paths_thunk=res_paths_thunk_))
def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):

View File

@ -27,6 +27,7 @@ from jax import lax
from jax.experimental import shard_map
from jax._src import api
from jax._src import api_util
from jax._src import ad_checkpoint
from jax._src import linear_util as lu
from jax._src import config
@ -39,7 +40,6 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -1202,8 +1202,10 @@ def checkify(f: Callable[..., Out],
in_tree = jtu.tree_structure(((), {}))
closed_f = lambda: f(*args, **kwargs)
# stage:
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:

View File

@ -91,6 +91,22 @@ class JaxprDebugInfo(NamedTuple):
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected
def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
"""Get the result_paths with a safety check."""
if len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
@ -149,8 +165,9 @@ class Jaxpr:
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
def __str__(self):
return str(self.pretty_print())

View File

@ -473,6 +473,9 @@ class custom_partitioning:
def __call__(self, *args, **kwargs):
args = _resolve_kwargs(self.fun, args, kwargs)
debug = api_util.tracing_debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
if self.static_argnums:
static_argnums = set(self.static_argnums)
args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
@ -491,8 +494,6 @@ class custom_partitioning:
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False,
"custom_partitioning")
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)

View File

@ -35,7 +35,6 @@ from jax._src import profiler
from jax._src import source_info_util
from jax._src import compute_on
from jax._src import xla_metadata as xla_metadata_lib
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs)
from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
Var, DropVar, Atom,
@ -44,7 +43,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src import tree_util
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
@ -1913,8 +1912,7 @@ class DynamicJaxprTrace(core.Trace):
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
dbg = tracing_debug_info_final(f, call_primitive.name)
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg)
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
propagate_source_info=False)
@ -1940,7 +1938,8 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
def process_map(self, map_primitive, f, tracers, params):
def process_map(self, map_primitive, f: lu.WrappedFun,
tracers: Sequence[core.Tracer], params):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
axis_name, axis_size = params['axis_name'], params['axis_size']
@ -1949,8 +1948,7 @@ class DynamicJaxprTrace(core.Trace):
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
f, reduced_in_avals,
debug_info=tracing_debug_info_final(f, map_primitive.name))
f, reduced_in_avals, f.debug_info)
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -2050,7 +2048,7 @@ class DynamicJaxprTrace(core.Trace):
closed_call_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(call_jaxpr), ())
transpose_flat, in_tree2 = flatten_fun_nokwargs(
transpose_flat, in_tree2 = api_util.flatten_fun_nokwargs(
lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
# the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
@ -2111,42 +2109,6 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
store.store(out_zeros)
return [*out_primals, *out_nz_tangents]
# Callers should be using linear_util.debug_info instead!
def tracing_debug_info(
fn: Callable,
in_tree: PyTreeDef,
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo:
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grad the debugging information when we have
# the unflattened args
# TODO(necula): in general we can just pretend the leaves are booleans, but
# when we use custom pytrees, the flattening functions may check the type
# of the argument
try:
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
except:
# TODO(necula): remove this catch-all. Repro in batching_test:test_basic_jit
dummy_args = ([False], {}) if has_kwargs else [False]
args, kwargs = dummy_args if has_kwargs else (dummy_args, {}) # type: ignore
def res_paths_thunk() -> tuple[str, ...]:
out_tree = out_tree_thunk()
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
return tuple(tree_util.keystr(path)
for path, _ in tree_util.generate_key_paths(dummy_result))
return api_util.tracing_debug_info(traced_for, fn, args, kwargs,
result_paths_thunk=res_paths_thunk)
def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
fn_trees = flattened_fun_in_tree(fn)
if fn_trees is None:
# TODO(necula): eliminate this branch
return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f),
(None,), None)
in_tree, out_tree_thunk, has_kws = fn_trees
return tracing_debug_info(fn.f, in_tree, out_tree_thunk, has_kws, traced_for)
@profiler.annotate_function
def trace_to_jaxpr_dynamic(

View File

@ -652,6 +652,7 @@ class ParallelCallableInfo:
in_axes: Iterable[int | None]
out_axes_thunk: Callable[[], Sequence[int | None]]
avals: Sequence[core.AbstractValue]
debug_info: api_util.TracingDebugInfo | None
@cached_property
def local_devices(self):
@ -722,8 +723,8 @@ def stage_parallel_callable(
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
fun, sharded_avals, pci.debug_info)
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)
assert len(out_sharded_avals) == len(pci.out_axes), (
len(out_sharded_avals), len(pci.out_axes))
@ -757,7 +758,7 @@ def get_pmap_jaxpr(
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
in_axes, out_axes_thunk, avals, fun.debug_info)
with core.extend_axis_env_nd([(axis_name, axis_size)]):
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
@ -880,8 +881,8 @@ def lower_parallel_callable(
replicated_args=replicated_args,
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
num_replicas=replicas.num_global_replicas,
lowering_parameters=lowering_parameters)
return PmapComputation(lowering_result.module,
@ -1971,8 +1972,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
result_shardings=out_mlir_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
num_replicas=nreps,
num_partitions=num_partitions,
all_default_mem_kind=all_default_mem_kind,

View File

@ -71,7 +71,7 @@ import weakref
from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.util import curry, cache_clearing_funs
from jax._src.util import curry, cache_clearing_funs, HashableFunction
traceback_util.register_exclusion(__file__)
@ -175,11 +175,11 @@ class WrappedFun:
if out_store is None:
return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args),
((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None, None)
(out_store,) + self.stores, self.params, None, self.debug_info)
else:
return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args),
((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None, None)
(out_store,) + self.stores, self.params, None, self.debug_info)
def populate_stores(self, stores):
"""Copy the values from the `stores` into `self.stores`."""
@ -282,6 +282,21 @@ class TracingDebugInfo(NamedTuple):
jaxpr_dbg.arg_names,
lambda: jaxpr_dbg.result_paths)
def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
) -> TracingDebugInfo:
assert self.result_paths_thunk is None
return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
closure=()))
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected
def wrap_init(f: Callable, params=None, *,
debug_info: TracingDebugInfo | None = None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""

View File

@ -569,6 +569,7 @@ def _infer_params_impl(
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
del args
@ -1160,7 +1161,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
if not config.dynamic_shapes.value and not attrs_tracked:
pjit_check_aval_sharding(in_shardings_flat, in_avals,
None if debug_info is None else debug_info.arg_names,
None if debug_info is None else debug_info.safe_arg_names(len(in_avals)),
"pjit arguments", allow_uneven_sharding=False)
check_aval_layout_compatibility(
in_layouts_flat, in_avals,
@ -1302,7 +1303,7 @@ def _create_pjit_jaxpr(
in_type: core.InputType | Sequence[core.AbstractValue],
attr_data: int,
debug_info: lu.TracingDebugInfo,
out_paths: Callable,
result_paths: Callable,
ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
@ -1314,19 +1315,18 @@ def _create_pjit_jaxpr(
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
attrs_tracked = []
else:
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_type, debug_info=pe_debug)
fun, in_type, debug_info=debug_info)
# assert attr_data is sentinel or attr_data matches attrs_tracked
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
if not config.dynamic_shapes.value and not attrs_tracked:
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())
if config.debug_key_reuse.value:
# Import here to avoid circular imports
@ -1366,11 +1366,12 @@ def _check_and_canonicalize_out_shardings(
if not config.dynamic_shapes.value:
pjit_check_aval_sharding(
out_shardings_flat, out_avals,
None if debug_info is None else debug_info.result_paths,
None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
"pjit outputs", allow_uneven_sharding=False)
check_aval_layout_compatibility(
out_layouts_flat, out_avals,
None if debug_info is None else debug_info.result_paths, "jit outputs")
None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
"jit outputs")
return out_shardings_flat, out_layouts_flat
@ -1423,9 +1424,9 @@ class IgnoreKey:
def pjit_check_aval_sharding(
shardings, flat_avals, names: tuple[str, ...] | None,
shardings, flat_avals, names: tuple[str | None, ...] | None,
what_aval: str, allow_uneven_sharding: bool):
new_names = [''] * len(shardings) if names is None else names
new_names = [None] * len(shardings) if names is None else names
for aval, s, name in zip(flat_avals, shardings, new_names):
if isinstance(s, (UnspecifiedValue, AUTO)):
continue

View File

@ -57,8 +57,6 @@ from jax._src.interpreters.partial_eval import (
dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
dce_jaxpr_consts as dce_jaxpr_consts,
dce_rules as dce_rules,
tracing_debug_info as tracing_debug_info,
tracing_debug_info_final as tracing_debug_info_final,
def_trivial_padding as def_trivial_padding,
forwarding_rules as forwarding_rules,
has_effects as has_effects,

View File

@ -639,8 +639,7 @@ class DebugInfoTest(jtu.JaxTestCase):
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_g, arg_names=b, from b",
# TODO(necula): bad arg name
"traced_for=jit, fun=my_f, arg_names=args[0], from args[0]"
"traced_for=jit, fun=my_f, arg_names=a, from a",
])
def test_jit_arg_names(self):
@ -692,8 +691,7 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): the arg_names are not right, include static ones, also missing args[1]
"traced_for=jit, fun=my_f, arg_names=x['hi'],y,z,args[0],kwargs['t'],kwargs['w'], from kwargs['w']",
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']",
"None", # TODO(necula)
],
expected_lowering_lines=[
@ -1177,8 +1175,7 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): bad tracer provenance
"traced_for=pmap, fun=my_f, arg_names=x,y,args[0],a,kwargs['b'],kwargs['d'], from args[0]",
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]",
],
)
@ -1219,6 +1216,31 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
# TODO(necula): why this?
re.compile(r'traced_for=jit, fun=_multi_slice at .*/array_methods.py:.*, arg_names=self, result_paths=.*'),
"traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=['u'],['v']",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
# TODO(necula): missing debug_info
'None'
],
)
@jtu.ignore_warning(category=UserWarning,
message=".* jitted function .* includes a pmap")
def test_jvp_pmap(self):
tracer_spy = TracerSpy()
def my_f(x, y):
tracer_spy.append(x)
return jnp.sin(x) + y
x = np.ones((jax.device_count(), 1), dtype=np.float32)
x_tan = np.full_like(x, .1)
self._check_tracers_and_jaxprs(
jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))),
x, x_tan,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=[0],[1]',
"None", # TODO(necula): missing debug info
],
tracer_spy=tracer_spy,
@ -1307,7 +1329,7 @@ class DebugInfoTest(jtu.JaxTestCase):
"None", # TODO(necula): missing tracer debug info
],
expected_tracer_debug_infos=[
"traced_for=xla_pmap, fun=my_f, arg_names=my_x",
"traced_for=pmap, fun=my_f, arg_names=my_x",
],
check_lowering=False, # TODO(necula): warning during lowering
)
@ -1339,7 +1361,8 @@ class DebugInfoTest(jtu.JaxTestCase):
# TODO(necula): some Jaxprs without debug info
'None'],
expected_tracer_debug_infos=[
"traced_for=custom_dce, fun=my_g, arg_names=x"
# TODO(necula): no leaked tracer from my_g_dce?
"traced_for=custom_dce, fun=my_g, arg_names=x",
])
def test_custom_dce_consts(self):
@ -1350,7 +1373,7 @@ class DebugInfoTest(jtu.JaxTestCase):
return np.eye(1) * jnp.sin(x), jnp.cos(x)
@my_f.def_dce
def rule(used_outs, y):
def my_rule(used_outs, y):
tracer_spy.append(y)
return (
np.full((1, 1), 2.0) * jnp.exp(y) if used_outs[0] else None,
@ -1367,8 +1390,8 @@ class DebugInfoTest(jtu.JaxTestCase):
'None'],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): no leaked tracer from my_rule?
"traced_for=custom_dce, fun=my_f, arg_names=x, from x",
# TODO(necula): rule.y does not have an inspected tracer?
])
def test_custom_linear_solve_complex(self):