Mirror of https://github.com/ROCm/jax.git
[better_errors] Refactor more uses of pe.tracing_debug_info (part 3)
We replace uses of `pe.tracing_debug_info` with `api_util.tracing_debug_info`, which uses the actual args and kwargs instead of manufacturing fake ones from `in_tree`. This ends up being more accurate, especially for `arg_names`; see the changes in debug_info_test.py. It means that we have to construct the debug info further upstream, before flattening args. This will later help populate debug info in `WrappedFun` and `Jaxpr`.

This is part 3 of a series (following #26097, #26099) covering jit, pmap, checkify, and custom_partitioning (the last few uses).

In order to land this, I had to remove a safety check that the number of `arg_names` and `result_paths` in a Jaxpr's debug info match the number of Jaxpr invars and outvars, respectively. Additionally, I added two accessors, `safe_arg_names` and `safe_result_paths`, to ensure that the arg names and result paths match the expected length. These accessors return no-op results when the lengths are not as expected. From my testing, this happens only in Jaxprs that are not used for lowering, hence there is no actual user-visible change here. Simply, more internal Jaxprs are getting debug_info, and in some cases the `arg_names` and `result_paths` are not correct. Still, this change is worth it because the `func_src_info` is the most useful part of the debug info (used for leaked tracers), and that is accurate. We will fix the `arg_names` and `result_paths` in a future change.

The changes in debug_info_test.py show the improvements in the user-visible debug info, including for `pjit` and `pmap` cases where it was previously wrong.
This commit is contained in:
parent d223dfc3f7
commit 32c98b9a76
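To see why building debug info from the actual args and kwargs matters, consider a minimal, self-contained sketch (not JAX's actual implementation; `arg_names_from_call` is a name invented for illustration):

import inspect

def arg_names_from_call(fn, args, kwargs):
  # Bind the actual call to fn's signature, so every argument gets the
  # parameter name the user actually wrote at the call site.
  bound = inspect.signature(fn).bind(*args, **kwargs)
  return tuple(bound.arguments)

def my_f(a, b):
  return a + b

print(arg_names_from_call(my_f, (1,), {"b": 2}))  # ('a', 'b')

Given only an `in_tree`, the tracer must unflatten placeholder leaves and guess at names, which is how the old path could produce `args[0]` where `a` was meant (see the test expectations updated below).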
@@ -61,7 +61,7 @@ from jax._src.api_util import (
     flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
-    result_paths, flat_out_axes, debug_info_final)
+    result_paths, flat_out_axes)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -1420,14 +1420,14 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
   return global_axis_size


-def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
+def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
                   donate_tuple, in_devices, backend_name,
                   axis_size, args, kwargs):
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")

   dbg = tracing_debug_info(
-      'pmap', fun, args, kwargs,
+      "pmap", fun, args, kwargs,
       static_argnums=static_broadcasted_tuple)

   f = lu.wrap_init(fun)
@@ -1478,9 +1478,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

   f, res_paths = result_paths(f)
+  dbg = dbg.add_result_paths(res_paths)
+  f = lu.add_debug_info(f, dbg)
   f, out_axes_thunk = flat_out_axes(f, out_axes)
   flat_fun, out_tree = flatten_fun(f, in_tree)
-  flat_fun = debug_info_final(flat_fun, dbg, res_paths)

   is_explicit_global_axis_size = axis_size is not None
   global_axis_size = _get_global_axis_size(local_axis_size, in_devices,
@@ -100,29 +100,6 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
   ans = fun(*args)
   return tree_unflatten(out_tree, ans)

-def flattened_fun_in_tree(
-    fn: lu.WrappedFun
-) -> tuple[PyTreeDef, Callable[[], PyTreeDef], bool] | None:
-  # This implementation relies on internal details of linear_util.py's
-  # WrappedFun, but it's for the worthy cause of better user error messages.
-  # It can fail (i.e. return None) if its WrappedFun argument is not transformed
-  # with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
-  # core.eval_jaxpr encounters a call primitive (though at that point we're just
-  # round-tripping jaxprs and the user errors in question are impossible).
-  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
-  assert (isinstance(flatten_fun_nokwargs, partial) and
-          len(flatten_fun_nokwargs.args) == 1)
-  flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
-  try:
-    ((in_tree,), out_tree_store, has_kwargs), = (
-        (args, store, f is flatten_fun.args[0])
-        for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
-  except ValueError:
-    # When `fn` is not the result of flatten_fun or flatten_fun_nokwargs
-    return None
-  else:
-    return in_tree, lambda: out_tree_store.val, has_kwargs  # type: ignore[union-attr]
-
 @lu.transformation_with_aux2
 def flatten_fun_nokwargs2(f, store, in_tree, *args_flat):
   py_args = tree_unflatten(in_tree, args_flat)
@@ -705,6 +682,7 @@ def result_paths(_fun, _store, *args, **kwargs):
   _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
   return ans

+# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
 def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
                          trace_debug: TracingDebugInfo | None,
                          result_paths: tuple[str, ...] | None = None,
@@ -712,7 +690,8 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
   """Add debug info to jaxpr, given trace-time debug info and result paths."""
   if trace_debug is None:
     return jaxpr
-  assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
+  # TODO(necula): re-enable this safety check
+  # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
   if result_paths is None:
     result_paths = trace_debug.result_paths_thunk()  # type: ignore
   debug_info = core.JaxprDebugInfo(
@@ -720,16 +699,6 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
       trace_debug.arg_names, tuple(result_paths))  # type: ignore
   return jaxpr.replace(debug_info=debug_info)

-def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
-                     res_paths_thunk: Callable[[], tuple[str, ...]]
-                     ) -> lu.WrappedFun:
-  "Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
-  if dbg is None: return f
-  assert dbg.result_paths_thunk is None
-  res_paths_thunk_ = HashableFunction(res_paths_thunk, closure=())
-  return lu.add_debug_info(f, dbg._replace(result_paths_thunk=res_paths_thunk_))
-
-
 def hoist_obj_attrs(f, flat_args):
   idxs, objs, flat_args_ = [], [], []
   for i, x in enumerate(flat_args):
@@ -27,6 +27,7 @@ from jax import lax

 from jax.experimental import shard_map
 from jax._src import api
+from jax._src import api_util
 from jax._src import ad_checkpoint
 from jax._src import linear_util as lu
 from jax._src import config
@@ -39,7 +40,6 @@ from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import tree_util as jtu
 from jax._src.ad_util import SymbolicZero
-from jax._src.api_util import flatten_fun
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -1202,8 +1202,10 @@ def checkify(f: Callable[..., Out],
   in_tree = jtu.tree_structure(((), {}))
   closed_f = lambda: f(*args, **kwargs)
   # stage:
-  fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
-  debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
+  debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
+  fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
+                                                     debug_info=debug),
+                                        in_tree)
   jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
   jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
   # checkify:
@@ -91,6 +91,22 @@ class JaxprDebugInfo(NamedTuple):
   # This is formed after tracing, when we have concrete `result_paths`
   result_paths: tuple[str, ...]  # e.g. ('[0]', '[1]', ...)

+  def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
+    """Get the arg_names with a safety check."""
+    if len(self.arg_names) == expected:
+      return self.arg_names
+    else:
+      # TODO(necula): this should not happen
+      return (None,) * expected
+
+  def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
+    """Get the result_paths with a safety check."""
+    if len(self.result_paths) == expected:
+      return self.result_paths
+    else:
+      # TODO(necula): this should not happen
+      return ("",) * expected
+

 class Jaxpr:
   __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
@@ -149,8 +165,9 @@ class Jaxpr:
     self._eqns = list(eqns)
     self._effects = effects
     self._debug_info = debug_info
-    assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
-    assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
+    # TODO(necula): re-enable these safety checks
+    # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
+    # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

   def __str__(self):
     return str(self.pretty_print())
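The accessors added above degrade gracefully instead of asserting. A minimal standalone mock (not the real `JaxprDebugInfo` class) showing the intended semantics:

from typing import NamedTuple

class MockDebugInfo(NamedTuple):
  arg_names: tuple

  def safe_arg_names(self, expected: int):
    # On a length mismatch, return placeholders instead of failing.
    if len(self.arg_names) == expected:
      return self.arg_names
    return (None,) * expected

dbg = MockDebugInfo(("x", "y"))
print(dbg.safe_arg_names(2))  # ('x', 'y')
print(dbg.safe_arg_names(3))  # (None, None, None), the no-op fallback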
@@ -473,6 +473,9 @@ class custom_partitioning:

   def __call__(self, *args, **kwargs):
     args = _resolve_kwargs(self.fun, args, kwargs)
+    debug = api_util.tracing_debug_info("custom_partitioning", self.fun,
+                                        args, kwargs,
+                                        static_argnums=self.static_argnums)
     if self.static_argnums:
       static_argnums = set(self.static_argnums)
       args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
@@ -491,8 +494,6 @@ class custom_partitioning:
     args_flat, in_tree = tree_util.tree_flatten(dyn_args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
     in_avals = [core.get_aval(x) for x in args_flat]
-    debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False,
-                                  "custom_partitioning")
     mesh = mesh_lib.thread_resources.env.physical_mesh
     with core.extend_axis_env_nd(mesh.shape.items()):
       jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
@@ -35,7 +35,6 @@ from jax._src import profiler
 from jax._src import source_info_util
 from jax._src import compute_on
 from jax._src import xla_metadata as xla_metadata_lib
-from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs)
 from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
                            AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
                            Var, DropVar, Atom,
@@ -44,7 +43,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
                            InputType, OutputType, get_referent, JaxprEqnContext)
 from jax._src.state.types import AbstractRef
 from jax._src import tree_util
-from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
+from jax._src.tree_util import (PyTreeDef, treedef_tuple,
                                 tree_flatten, tree_structure)
 from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
                            merge_lists, partition_list, OrderedSet,
@@ -1913,8 +1912,7 @@ class DynamicJaxprTrace(core.Trace):
     implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
     in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
     # TODO(mattjj): check in_tracers are consistent with f.in_type annotation
-    dbg = tracing_debug_info_final(f, call_primitive.name)
-    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg)
+    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
     if params.get('inline', False):
       return core.eval_jaxpr(jaxpr, consts, *in_tracers,
                              propagate_source_info=False)
@@ -1940,7 +1938,8 @@ class DynamicJaxprTrace(core.Trace):
     self.frame.add_eqn(eqn)
     return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]

-  def process_map(self, map_primitive, f, tracers, params):
+  def process_map(self, map_primitive, f: lu.WrappedFun,
+                  tracers: Sequence[core.Tracer], params):
     tracers = map(self.to_jaxpr_tracer, tracers)
     in_avals = [t.aval for t in tracers]
     axis_name, axis_size = params['axis_name'], params['axis_size']
@@ -1949,8 +1948,7 @@ class DynamicJaxprTrace(core.Trace):
                        for a, in_axis in zip(in_avals, params['in_axes'])]
     with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
       jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
-          f, reduced_in_avals,
-          debug_info=tracing_debug_info_final(f, map_primitive.name))
+          f, reduced_in_avals, f.debug_info)
     ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
     if ordered_effects:
       raise ValueError("Ordered effects not supported for "
@@ -2050,7 +2048,7 @@ class DynamicJaxprTrace(core.Trace):
     closed_call_jaxpr = core.ClosedJaxpr(
         convert_constvars_jaxpr(call_jaxpr), ())

-    transpose_flat, in_tree2 = flatten_fun_nokwargs(
+    transpose_flat, in_tree2 = api_util.flatten_fun_nokwargs(
         lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))

     # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
@@ -2111,42 +2109,6 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
   store.store(out_zeros)
   return [*out_primals, *out_nz_tangents]

-# Callers should be using linear_util.debug_info instead!
-def tracing_debug_info(
-    fn: Callable,
-    in_tree: PyTreeDef,
-    out_tree_thunk: Callable[[], PyTreeDef],
-    has_kwargs: bool,
-    traced_for: str
-) -> lu.TracingDebugInfo:
-  # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
-  # We just have to make sure we grad the debugging information when we have
-  # the unflattened args
-  # TODO(necula): in general we can just pretend the leaves are booleans, but
-  # when we use custom pytrees, the flattening functions may check the type
-  # of the argument
-  try:
-    dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)  # type: ignore
-  except:
-    # TODO(necula): remove this catch-all. Repro in batching_test:test_basic_jit
-    dummy_args = ([False], {}) if has_kwargs else [False]
-  args, kwargs = dummy_args if has_kwargs else (dummy_args, {})  # type: ignore
-  def res_paths_thunk() -> tuple[str, ...]:
-    out_tree = out_tree_thunk()
-    dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
-    return tuple(tree_util.keystr(path)
-                 for path, _ in tree_util.generate_key_paths(dummy_result))
-  return api_util.tracing_debug_info(traced_for, fn, args, kwargs,
-                                     result_paths_thunk=res_paths_thunk)
-
-def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
-  fn_trees = flattened_fun_in_tree(fn)
-  if fn_trees is None:
-    # TODO(necula): eliminate this branch
-    return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f),
-                               (None,), None)
-  in_tree, out_tree_thunk, has_kws = fn_trees
-  return tracing_debug_info(fn.f, in_tree, out_tree_thunk, has_kws, traced_for)

 @profiler.annotate_function
 def trace_to_jaxpr_dynamic(
@@ -652,6 +652,7 @@ class ParallelCallableInfo:
   in_axes: Iterable[int | None]
   out_axes_thunk: Callable[[], Sequence[int | None]]
   avals: Sequence[core.AbstractValue]
+  debug_info: api_util.TracingDebugInfo | None

   @cached_property
   def local_devices(self):
@@ -722,8 +723,8 @@ def stage_parallel_callable(
       "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
     jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
-        fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
-  jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
+        fun, sharded_avals, pci.debug_info)
+  jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)

   assert len(out_sharded_avals) == len(pci.out_axes), (
       len(out_sharded_avals), len(pci.out_axes))
@@ -757,7 +758,7 @@ def get_pmap_jaxpr(

   pci = ParallelCallableInfo(
       name, backend, axis_name, axis_size, global_axis_size, devices,
-      in_axes, out_axes_thunk, avals)
+      in_axes, out_axes_thunk, avals, fun.debug_info)
   with core.extend_axis_env_nd([(axis_name, axis_size)]):
     jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
   jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
@@ -880,8 +881,8 @@ def lower_parallel_callable(
       replicated_args=replicated_args,
       arg_shardings=None,
       result_shardings=None,
-      arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
-      result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
+      arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
+      result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
       num_replicas=replicas.num_global_replicas,
       lowering_parameters=lowering_parameters)
   return PmapComputation(lowering_result.module,
@@ -1971,8 +1972,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       result_shardings=out_mlir_shardings,
       in_layouts=in_layouts,
       out_layouts=out_layouts,
-      arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
-      result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
+      arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
+      result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
       num_replicas=nreps,
       num_partitions=num_partitions,
       all_default_mem_kind=all_default_mem_kind,
@@ -71,7 +71,7 @@ import weakref
 from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
-from jax._src.util import curry, cache_clearing_funs
+from jax._src.util import curry, cache_clearing_funs, HashableFunction


 traceback_util.register_exclusion(__file__)
@@ -175,11 +175,11 @@ class WrappedFun:
     if out_store is None:
       return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args),
                         ((gen, gen_static_args),) + self.transforms,
-                        (out_store,) + self.stores, self.params, None, None)
+                        (out_store,) + self.stores, self.params, None, self.debug_info)
     else:
       return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args),
                         ((gen, gen_static_args),) + self.transforms,
-                        (out_store,) + self.stores, self.params, None, None)
+                        (out_store,) + self.stores, self.params, None, self.debug_info)

   def populate_stores(self, stores):
     """Copy the values from the `stores` into `self.stores`."""
@@ -282,6 +282,21 @@ class TracingDebugInfo(NamedTuple):
                jaxpr_dbg.arg_names,
                lambda: jaxpr_dbg.result_paths)

+  def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
+                       ) -> TracingDebugInfo:
+    assert self.result_paths_thunk is None
+    return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
+                                                             closure=()))
+
+  def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
+    """Get the arg_names with a safety check."""
+    if len(self.arg_names) == expected:
+      return self.arg_names
+    else:
+      # TODO(necula): this should not happen
+      return (None,) * expected
+
+
 def wrap_init(f: Callable, params=None, *,
               debug_info: TracingDebugInfo | None = None) -> WrappedFun:
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
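The `add_result_paths` method above attaches result paths lazily, since they are only known after tracing. A minimal mock of the contract (the real code also wraps the thunk in `HashableFunction` so the debug info stays usable as a cache key):

class MockTracingDebugInfo:
  def __init__(self):
    self.result_paths_thunk = None

  def add_result_paths(self, thunk):
    # Mirrors the assert in the diff: the thunk may be installed only once.
    assert self.result_paths_thunk is None
    self.result_paths_thunk = thunk
    return self

info = MockTracingDebugInfo().add_result_paths(lambda: ("[0]", "[1]"))
print(info.result_paths_thunk())  # ('[0]', '[1]')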
@@ -569,6 +569,7 @@ def _infer_params_impl(

   f = lu.wrap_init(fun)
   f, res_paths = result_paths(f)
+  dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
   f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
   del args

@@ -1160,7 +1161,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
   attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
   if not config.dynamic_shapes.value and not attrs_tracked:
     pjit_check_aval_sharding(in_shardings_flat, in_avals,
-                             None if debug_info is None else debug_info.arg_names,
+                             None if debug_info is None else debug_info.safe_arg_names(len(in_avals)),
                              "pjit arguments", allow_uneven_sharding=False)
     check_aval_layout_compatibility(
         in_layouts_flat, in_avals,
@@ -1302,7 +1303,7 @@ def _create_pjit_jaxpr(
     in_type: core.InputType | Sequence[core.AbstractValue],
     attr_data: int,
     debug_info: lu.TracingDebugInfo,
-    out_paths: Callable,
+    result_paths: Callable,
     ignored_inline: IgnoreKey
 ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
            list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
@@ -1314,19 +1315,18 @@ def _create_pjit_jaxpr(
   with dispatch.log_elapsed_time(
       "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
-    pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
     if config.dynamic_shapes.value:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
-          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
+          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
       attrs_tracked = []
     else:
       jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
-          fun, in_type, debug_info=pe_debug)
+          fun, in_type, debug_info=debug_info)
       # assert attr_data is sentinel or attr_data matches attrs_tracked

   # TODO(dougalm,mattjj): enable debug info with attrs_tracked
   if not config.dynamic_shapes.value and not attrs_tracked:
-    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
+    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())

   if config.debug_key_reuse.value:
     # Import here to avoid circular imports
@@ -1366,11 +1366,12 @@ def _check_and_canonicalize_out_shardings(
   if not config.dynamic_shapes.value:
     pjit_check_aval_sharding(
         out_shardings_flat, out_avals,
-        None if debug_info is None else debug_info.result_paths,
+        None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),  # type: ignore[arg-type]
         "pjit outputs", allow_uneven_sharding=False)
     check_aval_layout_compatibility(
         out_layouts_flat, out_avals,
-        None if debug_info is None else debug_info.result_paths, "jit outputs")
+        None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),  # type: ignore[arg-type]
+        "jit outputs")
   return out_shardings_flat, out_layouts_flat


@@ -1423,9 +1424,9 @@ class IgnoreKey:


 def pjit_check_aval_sharding(
-    shardings, flat_avals, names: tuple[str, ...] | None,
+    shardings, flat_avals, names: tuple[str | None, ...] | None,
     what_aval: str, allow_uneven_sharding: bool):
-  new_names = [''] * len(shardings) if names is None else names
+  new_names = [None] * len(shardings) if names is None else names
   for aval, s, name in zip(flat_avals, shardings, new_names):
     if isinstance(s, (UnspecifiedValue, AUTO)):
       continue
@@ -57,8 +57,6 @@ from jax._src.interpreters.partial_eval import (
     dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
     dce_jaxpr_consts as dce_jaxpr_consts,
     dce_rules as dce_rules,
-    tracing_debug_info as tracing_debug_info,
-    tracing_debug_info_final as tracing_debug_info_final,
    def_trivial_padding as def_trivial_padding,
    forwarding_rules as forwarding_rules,
    has_effects as has_effects,
@@ -639,8 +639,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[
             "traced_for=jit, fun=my_g, arg_names=b, from b",
-            # TODO(necula): bad arg name
-            "traced_for=jit, fun=my_f, arg_names=args[0], from args[0]"
+            "traced_for=jit, fun=my_f, arg_names=a, from a",
         ])

   def test_jit_arg_names(self):
@@ -692,8 +691,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[
-            # TODO(necula): the arg_names are not right, include static ones, also missing args[1]
-            "traced_for=jit, fun=my_f, arg_names=x['hi'],y,z,args[0],kwargs['t'],kwargs['w'], from kwargs['w']",
+            "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']",
             "None",  # TODO(necula)
         ],
         expected_lowering_lines=[
@@ -1177,8 +1175,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[
-            # TODO(necula): bad tracer provenance
-            "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],a,kwargs['b'],kwargs['d'], from args[0]",
+            "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]",
         ],
     )
@@ -1219,6 +1216,31 @@ class DebugInfoTest(jtu.JaxTestCase):
         expected_jaxpr_debug_infos=[
             # TODO(necula): why this?
             re.compile(r'traced_for=jit, fun=_multi_slice at .*/array_methods.py:.*, arg_names=self, result_paths=.*'),
+            "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=['u'],['v']",
         ],
+        tracer_spy=tracer_spy,
+        expected_tracer_debug_infos=[
+            # TODO(necula): missing debug_info
+            'None'
+        ],
     )

+  @jtu.ignore_warning(category=UserWarning,
+                      message=".* jitted function .* includes a pmap")
+  def test_jvp_pmap(self):
+    tracer_spy = TracerSpy()
+    def my_f(x, y):
+      tracer_spy.append(x)
+      return jnp.sin(x) + y
+
+    x = np.ones((jax.device_count(), 1), dtype=np.float32)
+    x_tan = np.full_like(x, .1)
+
+    self._check_tracers_and_jaxprs(
+        jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))),
+        x, x_tan,
+        expected_jaxpr_debug_infos=[
+            'traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=[0],[1]',
+            "None",  # TODO(necula): missing debug info
+        ],
+        tracer_spy=tracer_spy,
@@ -1307,7 +1329,7 @@ class DebugInfoTest(jtu.JaxTestCase):
             "None",  # TODO(necula): missing tracer debug info
         ],
         expected_tracer_debug_infos=[
-            "traced_for=xla_pmap, fun=my_f, arg_names=my_x",
+            "traced_for=pmap, fun=my_f, arg_names=my_x",
         ],
         check_lowering=False,  # TODO(necula): warning during lowering
     )
@@ -1339,7 +1361,8 @@ class DebugInfoTest(jtu.JaxTestCase):
             # TODO(necula): some Jaxprs without debug info
             'None'],
         expected_tracer_debug_infos=[
-            "traced_for=custom_dce, fun=my_g, arg_names=x"
+            # TODO(necula): no leaked tracer from my_g_dce?
+            "traced_for=custom_dce, fun=my_g, arg_names=x",
         ])

   def test_custom_dce_consts(self):
@@ -1350,7 +1373,7 @@ class DebugInfoTest(jtu.JaxTestCase):
       return np.eye(1) * jnp.sin(x), jnp.cos(x)

     @my_f.def_dce
-    def rule(used_outs, y):
+    def my_rule(used_outs, y):
       tracer_spy.append(y)
       return (
           np.full((1, 1), 2.0) * jnp.exp(y) if used_outs[0] else None,
@@ -1367,8 +1390,8 @@ class DebugInfoTest(jtu.JaxTestCase):
             'None'],
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[
-            # TODO(necula): no leaked tracer from my_rule?
             "traced_for=custom_dce, fun=my_f, arg_names=x, from x",
+            # TODO(necula): rule.y does not have an inspected tracer?
         ])

   def test_custom_linear_solve_complex(self):