mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #25876 from gnecula:debug_info_3
PiperOrigin-RevId: 715831527
This commit is contained in:
commit
70c1ee5d9c
@ -449,7 +449,7 @@ def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
|
|||||||
jaxpr = jaxpr_.jaxpr
|
jaxpr = jaxpr_.jaxpr
|
||||||
out_tree = lambda: tree_structure(out_shape)
|
out_tree = lambda: tree_structure(out_shape)
|
||||||
assert len(jaxpr.invars) == len(in_leaves)
|
assert len(jaxpr.invars) == len(in_leaves)
|
||||||
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
|
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
|
||||||
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
|
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
|
||||||
|
|
||||||
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
|
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
|
||||||
|
@ -653,9 +653,10 @@ def result_paths(_fun, _store, *args, **kwargs):
|
|||||||
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
|
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
|
||||||
result_paths: tuple[str, ...] | None = None,
|
trace_debug: TracingDebugInfo | None,
|
||||||
) -> core.Jaxpr:
|
result_paths: tuple[str, ...] | None = None,
|
||||||
|
) -> core.Jaxpr:
|
||||||
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
|
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
|
||||||
if trace_debug is None:
|
if trace_debug is None:
|
||||||
return jaxpr
|
return jaxpr
|
||||||
|
@ -834,7 +834,7 @@ def checkify_while_body_jaxpr(
|
|||||||
new_body_f_ = lu.wrap_init(new_body_f)
|
new_body_f_ = lu.wrap_init(new_body_f)
|
||||||
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
|
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
|
||||||
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
|
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
|
||||||
*body_jaxpr.in_avals])
|
*body_jaxpr.in_avals])
|
||||||
closed_jaxpr = pe.close_jaxpr(jaxpr)
|
closed_jaxpr = pe.close_jaxpr(jaxpr)
|
||||||
err_vals, err_tree = jtu.tree_flatten(error)
|
err_vals, err_tree = jtu.tree_flatten(error)
|
||||||
err_vals = map(core.get_aval, err_vals)
|
err_vals = map(core.get_aval, err_vals)
|
||||||
@ -1203,7 +1203,7 @@ def checkify(f: Callable[..., Out],
|
|||||||
closed_f = lambda: f(*args, **kwargs)
|
closed_f = lambda: f(*args, **kwargs)
|
||||||
# stage:
|
# stage:
|
||||||
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
|
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
|
||||||
debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
|
debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
|
||||||
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
|
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
|
||||||
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
|
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
|
||||||
# checkify:
|
# checkify:
|
||||||
|
@ -150,7 +150,7 @@ class custom_vmap:
|
|||||||
args_flat, in_tree = tree_flatten(args)
|
args_flat, in_tree = tree_flatten(args)
|
||||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||||
in_avals = [core.get_aval(x) for x in args_flat]
|
in_avals = [core.get_aval(x) for x in args_flat]
|
||||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
|
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
|
||||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||||
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
||||||
|
@ -488,7 +488,7 @@ class custom_partitioning:
|
|||||||
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
|
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
|
||||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
|
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
|
||||||
in_avals = [core.get_aval(x) for x in args_flat]
|
in_avals = [core.get_aval(x) for x in args_flat]
|
||||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
|
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False,
|
||||||
"custom_partitioning")
|
"custom_partitioning")
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||||
|
@ -1909,7 +1909,7 @@ class DynamicJaxprTrace(core.Trace):
|
|||||||
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
|
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
|
||||||
in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
|
in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
|
||||||
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
|
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
|
||||||
dbg = debug_info_final(f, call_primitive.name)
|
dbg = tracing_debug_info_final(f, call_primitive.name)
|
||||||
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg)
|
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg)
|
||||||
if params.get('inline', False):
|
if params.get('inline', False):
|
||||||
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
|
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
|
||||||
@ -1946,7 +1946,7 @@ class DynamicJaxprTrace(core.Trace):
|
|||||||
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
|
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
|
||||||
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
|
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
|
||||||
f, reduced_in_avals,
|
f, reduced_in_avals,
|
||||||
debug_info=debug_info_final(f, map_primitive.name))
|
debug_info=tracing_debug_info_final(f, map_primitive.name))
|
||||||
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
||||||
if ordered_effects:
|
if ordered_effects:
|
||||||
raise ValueError("Ordered effects not supported for "
|
raise ValueError("Ordered effects not supported for "
|
||||||
@ -2108,14 +2108,15 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
|
|||||||
return [*out_primals, *out_nz_tangents]
|
return [*out_primals, *out_nz_tangents]
|
||||||
|
|
||||||
# Callers should be using linear_util.debug_info instead!
|
# Callers should be using linear_util.debug_info instead!
|
||||||
def debug_info(
|
def tracing_debug_info(
|
||||||
fn: Callable,
|
fn: Callable,
|
||||||
in_tree: PyTreeDef | None,
|
in_tree: PyTreeDef | None,
|
||||||
out_tree_thunk: Callable[[], PyTreeDef] | None,
|
out_tree_thunk: Callable[[], PyTreeDef] | None,
|
||||||
has_kwargs: bool,
|
has_kwargs: bool,
|
||||||
traced_for: str
|
traced_for: str
|
||||||
) -> lu.TracingDebugInfo | None:
|
) -> lu.TracingDebugInfo:
|
||||||
src_info = fun_sourceinfo(fn)
|
src_info = fun_sourceinfo(fn)
|
||||||
|
arg_names: tuple[str | None, ...] | None
|
||||||
try:
|
try:
|
||||||
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
|
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
|
||||||
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
|
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
|
||||||
@ -2123,19 +2124,20 @@ def debug_info(
|
|||||||
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
|
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
|
||||||
for path, _ in generate_key_paths(dummy))
|
for path, _ in generate_key_paths(dummy))
|
||||||
except:
|
except:
|
||||||
arg_names = None
|
arg_names = None # TODO(necula): we should not need this
|
||||||
def result_paths():
|
def result_paths():
|
||||||
try:
|
try:
|
||||||
out_tree = out_tree_thunk()
|
out_tree = out_tree_thunk()
|
||||||
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
|
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
|
||||||
except:
|
except:
|
||||||
return None
|
return None # TODO(necula): this does not seem to be needed
|
||||||
return tuple(path for path, _ in generate_key_paths(dummy_result))
|
return tuple(path for path, _ in generate_key_paths(dummy_result))
|
||||||
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore
|
# TODO(necula): clean up the type: ignore below
|
||||||
|
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore[arg-type]
|
||||||
|
|
||||||
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
|
def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
|
||||||
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
|
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
|
||||||
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
|
return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
|
||||||
|
|
||||||
|
|
||||||
@profiler.annotate_function
|
@profiler.annotate_function
|
||||||
|
@ -721,8 +721,8 @@ def stage_parallel_callable(
|
|||||||
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
|
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
|
||||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||||
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
|
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
|
||||||
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
|
fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
|
||||||
jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
|
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
|
||||||
|
|
||||||
assert len(out_sharded_avals) == len(pci.out_axes), (
|
assert len(out_sharded_avals) == len(pci.out_axes), (
|
||||||
len(out_sharded_avals), len(pci.out_axes))
|
len(out_sharded_avals), len(pci.out_axes))
|
||||||
@ -879,8 +879,8 @@ def lower_parallel_callable(
|
|||||||
replicated_args=replicated_args,
|
replicated_args=replicated_args,
|
||||||
arg_shardings=None,
|
arg_shardings=None,
|
||||||
result_shardings=None,
|
result_shardings=None,
|
||||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
|
||||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
|
||||||
num_replicas=replicas.num_global_replicas,
|
num_replicas=replicas.num_global_replicas,
|
||||||
lowering_parameters=lowering_parameters)
|
lowering_parameters=lowering_parameters)
|
||||||
return PmapComputation(lowering_result.module,
|
return PmapComputation(lowering_result.module,
|
||||||
@ -891,7 +891,7 @@ def lower_parallel_callable(
|
|||||||
ordered_effects=ordered_effects,
|
ordered_effects=ordered_effects,
|
||||||
keepalive=lowering_result.keepalive,
|
keepalive=lowering_result.keepalive,
|
||||||
host_callbacks=lowering_result.host_callbacks,
|
host_callbacks=lowering_result.host_callbacks,
|
||||||
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
|
jaxpr_debug_info=closed_jaxpr.jaxpr._debug_info,
|
||||||
shape_poly_state=lowering_result.shape_poly_state)
|
shape_poly_state=lowering_result.shape_poly_state)
|
||||||
|
|
||||||
|
|
||||||
@ -1833,7 +1833,7 @@ def _move_mutable_consts(
|
|||||||
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
|
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
|
||||||
from jax._src.state.discharge import discharge_state
|
from jax._src.state.discharge import discharge_state
|
||||||
jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts)
|
jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts)
|
||||||
jaxpr_._debug_info = jaxpr.jaxpr.debug_info
|
jaxpr_._debug_info = jaxpr.jaxpr._debug_info
|
||||||
return core.ClosedJaxpr(jaxpr_, consts)
|
return core.ClosedJaxpr(jaxpr_, consts)
|
||||||
|
|
||||||
|
|
||||||
@ -1971,8 +1971,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
|||||||
result_shardings=out_mlir_shardings,
|
result_shardings=out_mlir_shardings,
|
||||||
in_layouts=in_layouts,
|
in_layouts=in_layouts,
|
||||||
out_layouts=out_layouts,
|
out_layouts=out_layouts,
|
||||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
|
||||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
|
||||||
num_replicas=nreps,
|
num_replicas=nreps,
|
||||||
num_partitions=num_partitions,
|
num_partitions=num_partitions,
|
||||||
all_default_mem_kind=all_default_mem_kind,
|
all_default_mem_kind=all_default_mem_kind,
|
||||||
@ -2208,7 +2208,7 @@ def lower_sharding_computation(
|
|||||||
auto_spmd_lowering = check_if_any_auto(
|
auto_spmd_lowering = check_if_any_auto(
|
||||||
it.chain.from_iterable([in_shardings, out_shardings]))
|
it.chain.from_iterable([in_shardings, out_shardings]))
|
||||||
|
|
||||||
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr.debug_info)
|
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr._debug_info)
|
||||||
|
|
||||||
closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr(
|
closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr(
|
||||||
closed_jaxpr, api_name, fun_name, keep_unused, donated_invars,
|
closed_jaxpr, api_name, fun_name, keep_unused, donated_invars,
|
||||||
|
@ -53,8 +53,8 @@ def _typecheck_param(prim, param, name, msg_required, pred):
|
|||||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
|
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
|
||||||
primitive_name: str | None = None):
|
primitive_name: str | None = None):
|
||||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||||
debug = pe.debug_info(fun, in_tree, out_tree, False,
|
debug = pe.tracing_debug_info(fun, in_tree, out_tree, False,
|
||||||
primitive_name or "<unknown>")
|
primitive_name or "<unknown>")
|
||||||
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
|
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
|
||||||
wrapped_fun, in_avals, debug)
|
wrapped_fun, in_avals, debug)
|
||||||
return jaxpr, consts, out_tree(), attrs_tracked
|
return jaxpr, consts, out_tree(), attrs_tracked
|
||||||
|
@ -139,7 +139,7 @@ def switch(index, branches: Sequence[Callable], *operands,
|
|||||||
ops_avals = tuple(map(core.get_aval, ops))
|
ops_avals = tuple(map(core.get_aval, ops))
|
||||||
|
|
||||||
if config.mutable_array_checks.value:
|
if config.mutable_array_checks.value:
|
||||||
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
|
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')
|
||||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||||
|
|
||||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||||
@ -238,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
|||||||
ops_avals = tuple(map(core.get_aval, ops))
|
ops_avals = tuple(map(core.get_aval, ops))
|
||||||
|
|
||||||
if config.mutable_array_checks.value:
|
if config.mutable_array_checks.value:
|
||||||
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
|
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')
|
||||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||||
|
@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
|||||||
|
|
||||||
if config.mutable_array_checks.value:
|
if config.mutable_array_checks.value:
|
||||||
in_flat, in_tree = tree_flatten((init, xs))
|
in_flat, in_tree = tree_flatten((init, xs))
|
||||||
dbg = pe.debug_info(f, in_tree, None, False, 'scan')
|
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')
|
||||||
in_avals = tuple(_map(core.get_aval, in_flat))
|
in_avals = tuple(_map(core.get_aval, in_flat))
|
||||||
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
|
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
|
||||||
|
|
||||||
|
@ -657,7 +657,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
|
|||||||
@weakref_lru_cache
|
@weakref_lru_cache
|
||||||
def _trace_composite_to_jaxpr(fun, in_tree, in_avals):
|
def _trace_composite_to_jaxpr(fun, in_tree, in_avals):
|
||||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||||
debug_info = pe.debug_info(fun, in_tree, out_tree, False, "composite")
|
debug_info = pe.tracing_debug_info(fun, in_tree, out_tree, False, "composite")
|
||||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
|
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
|
||||||
# TODO(danfm): support const inputs to composite.
|
# TODO(danfm): support const inputs to composite.
|
||||||
assert not consts
|
assert not consts
|
||||||
|
@ -258,7 +258,7 @@ class TracingDebugInfo(NamedTuple):
|
|||||||
# formed just before staging to a jaxpr and read in trace-time error messages.
|
# formed just before staging to a jaxpr and read in trace-time error messages.
|
||||||
traced_for: str # e.g. 'jit', 'scan', etc
|
traced_for: str # e.g. 'jit', 'scan', etc
|
||||||
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
|
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
|
||||||
arg_names: tuple[str, ...] # e.g. ('args[0]', ... )
|
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
|
||||||
# e.g. ('[0]', '[1]', ...)
|
# e.g. ('[0]', '[1]', ...)
|
||||||
result_paths_thunk: Callable[[], tuple[str, ...]] | None
|
result_paths_thunk: Callable[[], tuple[str, ...]] | None
|
||||||
|
|
||||||
|
@ -426,7 +426,7 @@ class BlockSpec:
|
|||||||
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
|
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
|
||||||
lu.wrap_init(index_map_func), index_map_tree
|
lu.wrap_init(index_map_func), index_map_tree
|
||||||
)
|
)
|
||||||
debug = pe.debug_info(
|
debug = pe.tracing_debug_info(
|
||||||
index_map_func,
|
index_map_func,
|
||||||
index_map_tree,
|
index_map_tree,
|
||||||
index_map_out_tree_thunk,
|
index_map_out_tree_thunk,
|
||||||
|
@ -1340,7 +1340,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
|||||||
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
|
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
|
||||||
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||||
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
|
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
|
||||||
debug = pe.debug_info(
|
debug = pe.tracing_debug_info(
|
||||||
checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
|
checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
|
||||||
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
|
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
|
||||||
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||||
@ -1410,7 +1410,7 @@ def _trace_kernel_to_jaxpr(
|
|||||||
wrapped_kernel_fun = primitives.wrap_with_transforms(
|
wrapped_kernel_fun = primitives.wrap_with_transforms(
|
||||||
wrapped_kernel_fun, kernel_in_transforms
|
wrapped_kernel_fun, kernel_in_transforms
|
||||||
)
|
)
|
||||||
debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
|
debug = pe.tracing_debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
|
||||||
with grid_mapping.trace_env():
|
with grid_mapping.trace_env():
|
||||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
|
||||||
kernel_avals, debug)
|
kernel_avals, debug)
|
||||||
|
@ -47,11 +47,11 @@ from jax._src import tree_util
|
|||||||
from jax._src import util
|
from jax._src import util
|
||||||
from jax._src import xla_bridge as xb
|
from jax._src import xla_bridge as xb
|
||||||
from jax._src.api_util import (
|
from jax._src.api_util import (
|
||||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||||
donation_vector, check_callable, resolve_argnums,
|
donation_vector, check_callable, resolve_argnums,
|
||||||
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
|
argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
|
||||||
hoist_obj_attrs, _check_no_aliased_ref_args,
|
hoist_obj_attrs, _check_no_aliased_ref_args,
|
||||||
_check_no_aliased_closed_over_refs)
|
_check_no_aliased_closed_over_refs)
|
||||||
from jax._src.interpreters import partial_eval as pe
|
from jax._src.interpreters import partial_eval as pe
|
||||||
from jax._src.partition_spec import PartitionSpec
|
from jax._src.partition_spec import PartitionSpec
|
||||||
from jax._src.interpreters import xla
|
from jax._src.interpreters import xla
|
||||||
@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
|
|||||||
in_tree: PyTreeDef
|
in_tree: PyTreeDef
|
||||||
out_tree: PyTreeDef
|
out_tree: PyTreeDef
|
||||||
donated_invars: tuple[bool, ...]
|
donated_invars: tuple[bool, ...]
|
||||||
arg_names: tuple[str, ...] | None
|
arg_names: tuple[str | None, ...] | None
|
||||||
num_consts: int
|
num_consts: int
|
||||||
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
|
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
|
||||||
abstract_mesh: AbstractMesh
|
abstract_mesh: AbstractMesh
|
||||||
@ -637,7 +637,7 @@ def _infer_params_impl(
|
|||||||
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
|
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
|
||||||
out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
|
out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
|
||||||
ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
|
ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
|
||||||
tuple(out_avals), jaxpr.jaxpr.debug_info, device_or_backend_set)
|
tuple(out_avals), jaxpr.jaxpr._debug_info, device_or_backend_set)
|
||||||
|
|
||||||
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
|
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
|
||||||
|
|
||||||
@ -1301,7 +1301,7 @@ def _create_pjit_jaxpr(
|
|||||||
with dispatch.log_elapsed_time(
|
with dispatch.log_elapsed_time(
|
||||||
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
|
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
|
||||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||||
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
|
pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
|
||||||
if config.dynamic_shapes.value:
|
if config.dynamic_shapes.value:
|
||||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
|
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
|
||||||
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
|
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
|
||||||
@ -1313,7 +1313,7 @@ def _create_pjit_jaxpr(
|
|||||||
|
|
||||||
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
|
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
|
||||||
if not config.dynamic_shapes.value and not attrs_tracked:
|
if not config.dynamic_shapes.value and not attrs_tracked:
|
||||||
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
|
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
|
||||||
|
|
||||||
if config.debug_key_reuse.value:
|
if config.debug_key_reuse.value:
|
||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
@ -1672,7 +1672,7 @@ def _pjit_call_impl_python(
|
|||||||
if compiled._auto_spmd_lowering and config.enable_checks.value:
|
if compiled._auto_spmd_lowering and config.enable_checks.value:
|
||||||
pxla.check_array_xla_sharding_layout_match(
|
pxla.check_array_xla_sharding_layout_match(
|
||||||
args, compiled._in_shardings, compiled._in_layouts,
|
args, compiled._in_shardings, compiled._in_layouts,
|
||||||
jaxpr.jaxpr.debug_info, compiled._kept_var_idx)
|
jaxpr.jaxpr.tracing_debug_info, compiled._kept_var_idx)
|
||||||
if config.distributed_debug.value:
|
if config.distributed_debug.value:
|
||||||
# Defensively only perform fingerprint logic if debug logging is enabled
|
# Defensively only perform fingerprint logic if debug logging is enabled
|
||||||
# NOTE(skyewm): I didn't benchmark this
|
# NOTE(skyewm): I didn't benchmark this
|
||||||
|
@ -846,7 +846,7 @@ def _run_state_partial_eval_custom(
|
|||||||
out = run_state_p.bind(*args, **staged_params)
|
out = run_state_p.bind(*args, **staged_params)
|
||||||
return out[num_res:]
|
return out[num_res:]
|
||||||
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
|
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
|
||||||
[v.aval for v in res_staged_invars])
|
[v.aval for v in res_staged_invars])
|
||||||
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
|
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
|
||||||
staged_outvars,
|
staged_outvars,
|
||||||
core.closed_call_p,
|
core.closed_call_p,
|
||||||
@ -993,7 +993,7 @@ def initial_style_jaxpr(
|
|||||||
def _initial_style_jaxpr(fun, in_tree, in_avals):
|
def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||||
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
|
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
|
||||||
tree_util.treedef_tuple((in_tree,)))
|
tree_util.treedef_tuple((in_tree,)))
|
||||||
debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
|
debug = pe.tracing_debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
|
||||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||||
return jaxpr, consts, out_tree_thunk()
|
return jaxpr, consts, out_tree_thunk()
|
||||||
|
|
||||||
|
@ -57,8 +57,8 @@ from jax._src.interpreters.partial_eval import (
|
|||||||
dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
|
dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
|
||||||
dce_jaxpr_consts as dce_jaxpr_consts,
|
dce_jaxpr_consts as dce_jaxpr_consts,
|
||||||
dce_rules as dce_rules,
|
dce_rules as dce_rules,
|
||||||
debug_info as debug_info,
|
tracing_debug_info as tracing_debug_info,
|
||||||
debug_info_final as debug_info_final,
|
tracing_debug_info_final as tracing_debug_info_final,
|
||||||
def_trivial_padding as def_trivial_padding,
|
def_trivial_padding as def_trivial_padding,
|
||||||
forwarding_rules as forwarding_rules,
|
forwarding_rules as forwarding_rules,
|
||||||
has_effects as has_effects,
|
has_effects as has_effects,
|
||||||
|
@ -545,7 +545,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
return [a + 1]
|
return [a + 1]
|
||||||
in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
|
in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||||
self.assertLen(discharged_jaxpr.invars, 1)
|
self.assertLen(discharged_jaxpr.invars, 1)
|
||||||
@ -561,7 +561,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
return [a + 1]
|
return [a + 1]
|
||||||
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||||
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
|
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
|
||||||
self.assertLen(discharged_jaxpr.invars, 1)
|
self.assertLen(discharged_jaxpr.invars, 1)
|
||||||
@ -595,7 +595,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
||||||
core.ShapedArray((), jnp.dtype('float32'))]
|
core.ShapedArray((), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
# Discharging should just turn this into a jaxpr that ignores the first
|
# Discharging should just turn this into a jaxpr that ignores the first
|
||||||
# value and returns second value plus 1.
|
# value and returns second value plus 1.
|
||||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||||
@ -612,7 +612,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
return []
|
return []
|
||||||
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||||
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
|
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
|
||||||
self.assertLen(discharged_jaxpr.invars, 1)
|
self.assertLen(discharged_jaxpr.invars, 1)
|
||||||
@ -632,7 +632,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
return []
|
return []
|
||||||
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
discharged_jaxpr, discharged_consts = discharge_state(
|
discharged_jaxpr, discharged_consts = discharge_state(
|
||||||
stateful_jaxpr, consts)
|
stateful_jaxpr, consts)
|
||||||
inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
|
inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
|
||||||
@ -666,7 +666,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
||||||
core.ShapedArray((), jnp.dtype('float32'))]
|
core.ShapedArray((), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
# Discharging should just turn this into a jaxpr that adds the first value,
|
# Discharging should just turn this into a jaxpr that adds the first value,
|
||||||
# second value, and 1.
|
# second value, and 1.
|
||||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||||
@ -684,7 +684,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
return []
|
return []
|
||||||
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||||
self.assertLen(discharged_jaxpr.invars, 1)
|
self.assertLen(discharged_jaxpr.invars, 1)
|
||||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||||
@ -705,7 +705,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
return []
|
return []
|
||||||
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
discharged_jaxpr, discharged_consts = discharge_state(
|
discharged_jaxpr, discharged_consts = discharge_state(
|
||||||
stateful_jaxpr, consts)
|
stateful_jaxpr, consts)
|
||||||
inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
|
inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
|
||||||
@ -719,7 +719,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
return [a, b]
|
return [a, b]
|
||||||
in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
|
in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||||
self.assertLen(discharged_jaxpr.invars, 1)
|
self.assertLen(discharged_jaxpr.invars, 1)
|
||||||
self.assertLen(discharged_jaxpr.outvars, 3)
|
self.assertLen(discharged_jaxpr.outvars, 3)
|
||||||
@ -739,7 +739,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
|||||||
shaped_array_ref((4,), jnp.dtype('float32'))
|
shaped_array_ref((4,), jnp.dtype('float32'))
|
||||||
]
|
]
|
||||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||||
in_avals)
|
in_avals)
|
||||||
discharged_jaxpr, _ = discharge_state(
|
discharged_jaxpr, _ = discharge_state(
|
||||||
stateful_jaxpr, consts, should_discharge=[False, True])
|
stateful_jaxpr, consts, should_discharge=[False, True])
|
||||||
self.assertLen(discharged_jaxpr.invars, 2)
|
self.assertLen(discharged_jaxpr.invars, 2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user