Merge pull request #26399 from gnecula:debug_info_jaxpr_6

PiperOrigin-RevId: 725563435
This commit is contained in:
jax authors 2025-02-11 03:39:39 -08:00
commit 49eccd6c60
20 changed files with 215 additions and 113 deletions

View File

@ -27,7 +27,7 @@ from jax._src import dtypes
from jax._src.state.types import AbstractRef
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
treedef_children, generate_key_paths, keystr, broadcast_prefix,
treedef_children, generate_key_paths, broadcast_prefix,
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
@ -595,6 +595,15 @@ def debug_info(
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> core.DebugInfo:
"""Constructd core.DebugInfo for a function given example args and kwargs.
`args` and `kwargs` are example positional and keyword arguments, users with
`inspect.Signature` to get the names of argments. The arguments that are
considered static for tracing purposes should be included, and designated
using `static_argnums` and `static_argnames`.
See docstring for linear_util.DebugInfo.
"""
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
@ -671,12 +680,13 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
return tuple(f'{name}{lu._clean_keystr_arg_names(path)}'
for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names

View File

@ -833,7 +833,7 @@ def checkify_while_body_jaxpr(
# This checks if the next cond application will error
_ = cond_f(*c_consts, *out)
return out
new_body_f_ = lu.wrap_init(new_body_f)
new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
*body_jaxpr.in_avals])
@ -952,7 +952,8 @@ error_checks[ad_checkpoint.remat_p] = remat_error_check
def shard_map_error_check(
error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
error: Error, enabled_errors, *vals_in,
jaxpr: core.Jaxpr, in_names, out_names, **kwargs
):
if (mesh := kwargs.get('mesh')) is None:
raise ValueError('Mesh must be provided for shard_map with checkify.')
@ -976,7 +977,6 @@ def shard_map_error_check(
)
num_out_error_vals = out_tree.num_leaves - len(out_names)
@lu.wrap_init
def expand_errors_leading_dim(*xs):
outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
errs, outs = split_list(outs, [num_out_error_vals])
@ -985,7 +985,9 @@ def shard_map_error_check(
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
expand_errors_leading_dim, checked_jaxpr.in_avals
lu.wrap_init(expand_errors_leading_dim,
debug_info=checked_jaxpr.jaxpr.debug_info),
checked_jaxpr.in_avals
)
checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)
@ -993,7 +995,8 @@ def shard_map_error_check(
# Use fully sharded partitioning for out errors.
new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names)
subfun = lu.hashable_partial(
lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts
lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info),
checked_jaxpr.jaxpr, checked_jaxpr.consts
)
new_params = dict(
jaxpr=checked_jaxpr.jaxpr,
@ -1007,8 +1010,10 @@ def shard_map_error_check(
return tree_unflatten(out_tree, err_and_out)
error_checks[shard_map.shard_map_p] = shard_map_error_check
def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
jvp_jaxpr_thunk, call_jaxpr, **params):
def custom_jvp_call_rule(in_err: Error,
enabled_errors: set, *in_vals, num_consts,
jvp_jaxpr_fun: lu.WrappedFun,
call_jaxpr: core.ClosedJaxpr, **params):
# The types to have in mind are:
# jvp : (a -> b) -> (a, T a) -> (b, T b)
# checkify : (a -> b) -> a -> Err b
@ -1021,10 +1026,11 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
err_vals, err_tree = jtu.tree_flatten(in_err)
partial_checkify = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
call_jaxpr.consts, enabled_errors, err_tree))
call_jaxpr.consts, enabled_errors, err_tree),
debug_info=call_jaxpr.jaxpr.debug_info)
partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun)
jvp, jvp_out_tree = flatten_fun_output(jvp)
all_outs = custom_derivatives.custom_jvp_call_p.bind(
partial_checkify, jvp, *err_vals, *in_vals, **params)
@ -1041,17 +1047,17 @@ error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
# tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
@lu.wrap_init
def lift_jvp(num_errs: int, num_consts: int,
jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
zeros = [type(t) is SymbolicZero for t in tangents]
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
@ -1063,7 +1069,7 @@ def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
primal_errs = xs[num_consts:num_consts+num_errs]
tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
return jvp
return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
fun_jaxpr: core.ClosedJaxpr,

View File

@ -378,20 +378,19 @@ class CustomJVPCallPrimitive(core.Primitive):
new_params = dict(params)
call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
num_consts: int = new_params.pop('num_consts')
jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
jvp_jaxpr_fun = new_params.pop('jvp_jaxpr_fun')
fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
debug_info=call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_fun)
return [fun, jvp], new_params
def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
debug_info: core.DebugInfo | None) -> lu.WrappedFun:
def lift_jvp(num_consts: int, jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts:n], xs[n+num_consts:]
zeros = [type(t) is SymbolicZero for t in tangents]
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
@ -401,16 +400,16 @@ def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
return [*out_primals, *out_tangents]
return lu.wrap_init(jvp, debug_info=debug_info)
return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)
effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
# TODO(mattjj): could do more checking here...
del in_avals, jvp_jaxpr_thunk, num_consts
del in_avals, jvp_jaxpr_fun, num_consts
disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
@ -418,9 +417,9 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
return call_jaxpr.out_avals, call_jaxpr.effects
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck
def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
del jvp_jaxpr_thunk, num_consts, symbolic_zeros
del jvp_jaxpr_fun, num_consts, symbolic_zeros
consts = mlir._ir_consts(call_jaxpr.consts)
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
ctx.name_stack, ctx.tokens_in, consts,
@ -452,7 +451,7 @@ def _custom_jvp_call_dce(
return [False] * len(eqn.invars), None
call_jaxpr = eqn.params["call_jaxpr"]
jvp_jaxpr_thunk = eqn.params["jvp_jaxpr_thunk"]
jvp_jaxpr_fun = eqn.params["jvp_jaxpr_fun"]
# We must set instantiate=True because some inputs that are unused by the
# DCE'ed primal might be used in the JVP rule.
dce_call_jaxpr, used_ins = _cached_closed_call_dce_instantiate(
@ -461,7 +460,7 @@ def _custom_jvp_call_dce(
@pe._memoize
def dce_jvp_jaxpr_thunk(*in_zeros):
jvp_jaxpr, consts, out_zeros = jvp_jaxpr_thunk(*in_zeros)
jvp_jaxpr, consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*in_zeros)
dce_jvp_jaxpr, _ = pe.dce_jaxpr(jvp_jaxpr, [*used_outs, *used_outs], True)
dce_out_zeros = [v for used, v in zip(used_outs, out_zeros) if used]
return dce_jvp_jaxpr, consts, dce_out_zeros
@ -470,7 +469,8 @@ def _custom_jvp_call_dce(
new_params = dict(
eqn.params,
call_jaxpr=dce_call_jaxpr,
jvp_jaxpr_thunk=dce_jvp_jaxpr_thunk,
jvp_jaxpr_fun=lu.wrap_init(dce_jvp_jaxpr_thunk,
debug_info=jvp_jaxpr_fun.debug_info)
)
new_eqn = pe.new_jaxpr_eqn(
eqn.invars, outvars, eqn.primitive, new_params, dce_call_jaxpr.effects,

View File

@ -84,9 +84,12 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
return out_primals, out_tangents
@lu.transformation_with_aux2
def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
nzs_in: Sequence[bool],
debug_info: core.DebugInfo | None,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(None) # TODO(necula): fill-in the debug info (use JAX_USE_DIRECT_LINEARIZATION=1)
tangent_trace = pe.DynamicJaxprTrace(debug_info)
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
@ -99,7 +102,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
@ -678,11 +681,13 @@ class LinearizeTrace(Trace):
tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)
def process_call(self, call_primitive, f, tracers, params):
def process_call(self, call_primitive, f: lu.WrappedFun,
tracers, params):
assert call_primitive.multiple_results
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not Zero for t in tangents)
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in)
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in,
f.debug_info)
if isinstance(call_primitive, core.MapPrimitive):
@as_hashable_function(closure=(linearize_outs_thunk))
def new_out_axes_thunk():
@ -725,7 +730,9 @@ class LinearizeTrace(Trace):
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = call_primitive.bind_with_trace(
self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), new_params)
self.tangent_trace, (lu.wrap_init(f_tangent,
debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), new_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]

View File

@ -502,7 +502,7 @@ def _closed_call_param_updater(params, _, __):
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
call_param_updaters[core.closed_call_p] = _closed_call_param_updater
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params, debug_info=debug_info), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
@ -1992,7 +1992,9 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return out_tracers
def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros):
def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
jvp: lu.WrappedFun, tracers,
symbolic_zeros: bool):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
in_tangent_avals = [t.to_tangent_aval() for t in in_avals]
@ -2014,7 +2016,8 @@ class DynamicJaxprTrace(core.Trace):
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
jvp_jaxpr_fun=lu.wrap_init(jvp_jaxpr_thunk,
debug_info=jvp.debug_info),
num_consts=len(consts),
symbolic_zeros=symbolic_zeros),
fun_jaxpr.effects,

View File

@ -1255,7 +1255,9 @@ def _scan_state_partial_discharge_rule(should_discharge, in_avals, out_avals, *a
)
# TODO(cperivol): avoid tracing the jaxpr twice. When doing so don't
# forget to manage the effects.
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs)
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(wrapped, debug_info=discharged_jaxpr.debug_info),
avals_for_wrapped_no_refs)
all_out = scan_p.bind(*args_for_wrapped,
jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
length=length,
@ -1922,9 +1924,9 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
carry, refs_out = split_list(carry_refs, [num_carry])
return [*refs_out, *carry]
new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a
in ref_avals],
*carry_avals])
lu.wrap_init(new_body, debug_info=discharged_body_jaxpr.debug_info),
[*remaining_body_const_avals, *[a.inner_aval for a in ref_avals],
*carry_avals])
if new_body_consts: raise NotImplementedError
# Since some `Ref`s that were previously consts are now carries, we need to
@ -1936,9 +1938,8 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
del refs # We don't use them here!
return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_cond), [*cond_consts_avals,
*[a.inner_aval for a in ref_avals],
*carry_avals])
lu.wrap_init(new_cond, debug_info=cond_jaxpr.debug_info),
[*cond_consts_avals, *[a.inner_aval for a in ref_avals], *carry_avals])
if new_cond_consts: raise NotImplementedError
out = while_p.bind(*cond_consts, *remaining_body_consts, *refs, *carry,

View File

@ -159,7 +159,8 @@ def _root_jvp(const_lengths, jaxprs, primals, tangents):
linearize_and_solve = partial(
core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
f_at_solution = lambda *params: f(*params, *solution)
_, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
_, rhs = ad.jvp(lu.wrap_init(f_at_solution,
debug_info=jaxprs.f.jaxpr.debug_info)).call_wrapped(
params.f, params_dot.f)
solution_dot = _map(
operator.neg, linearize_and_solve(*solution, *rhs))

View File

@ -65,13 +65,14 @@ from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
import re
from typing import Any, NamedTuple
import weakref
from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import keystr, generate_key_paths
from jax._src.tree_util import keystr, KeyPath, generate_key_paths
from jax._src.util import curry, cache_clearing_funs, HashableFunction
@ -266,13 +267,17 @@ class DebugInfo(NamedTuple):
# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).
# Uses `None` for the args that do not correspond to user-named arguments,
# e.g., tangent args in jax.jvp.
# e.g., tangent args in jax.jvp. At the moment, `arg_names` accuracy is
# best-effort. Use `safe_arg_names` to detect and handle an unexpected
# number of elements in `arg_names`.
arg_names: tuple[str | None, ...]
# The result paths are not available while we are tracing the function,
# instead we keep a thunk. Once we are done tracing, we use
# `self.resolve_result_paths()` to execute the thunk and replace the
# actual result paths.
# actual result paths. At the moment, `result_paths` accuracy is
# best-effort. Use `safe_result_paths` to detect and handle an unexpected
# number of elements in `result_paths`.
# e.g. ('[0]', '[1]', ...)
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
@ -329,10 +334,16 @@ def wrap_init(f: Callable, params=None, *,
return fun
# We replace <flat index 0> with 0
_re_clean_keystr_arg_names = re.compile(r"<flat index ([^>]+)>")
def _clean_keystr_arg_names(k: KeyPath) -> str:
res = keystr(k)
return _re_clean_keystr_arg_names.sub(r"\1", res)
@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
result_paths = [keystr(path) for path, _ in generate_key_paths(ans)]
result_paths = [_clean_keystr_arg_names(path) for path, _ in generate_key_paths(ans)]
if _store:
# In some instances a lu.WrappedFun is called multiple times, e.g.,
# the bwd function in a custom_vjp

View File

@ -3044,11 +3044,11 @@ def _custom_jvp_call_lowering_rule(
ctx: LoweringRuleContext,
*args,
call_jaxpr: jax_core.Jaxpr,
jvp_jaxpr_thunk: Callable,
jvp_jaxpr_fun: lu.WrappedFun,
num_consts: int,
symbolic_zeros: bool,
):
del jvp_jaxpr_thunk
del jvp_jaxpr_fun
if symbolic_zeros: raise NotImplementedError
if num_consts: raise NotImplementedError
if call_jaxpr.consts: raise NotImplementedError

View File

@ -242,7 +242,9 @@ def _batch_block_mapping(
with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
lu.wrap_init(_block_map_function,
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
idx_avals)
shape = block_mapping.block_shape
if dim is batching.not_mapped:
new_block_shape = shape

View File

@ -73,7 +73,7 @@ def hoist_consts_to_refs(
return core.eval_jaxpr(jaxpr, all_consts, *args0, *args1)
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
lu.wrap_init(_hoist, debug_info=jaxpr.debug_info), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr

View File

@ -14,7 +14,7 @@
from __future__ import annotations
from typing import Any
from typing import Any, Callable
from jax._src import core
from jax._src import source_info_util
@ -90,7 +90,9 @@ def jvp(f, primals, tangents, attr_tangents):
primals_flat, in_tree = tree_flatten((attr_primals, *primals))
tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents))
if in_tree != in_tree_: raise Exception
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), in_tree)
dbg = api_util.debug_info("attrs_jvp", f, primals, {})
f_, out_tree = flatten_fun_nokwargs(
_set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree)
out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
primals_flat, tangents_flat)
out_primals = tree_unflatten(out_tree(), out_primals_flat)
@ -151,12 +153,14 @@ ad.JVPTrace.process_getattr = _getattr_jvp
ad.LinearizeTrace.process_setattr = _setattr_jvp
ad.LinearizeTrace.process_getattr = _getattr_jvp
def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []):
attr_primals = [jax_getattr(o, a) for o, a in attrs]
attr_avals = [core.get_aval(p) for p in attr_primals]
primals_flat, in_tree = tree_flatten(primals)
tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
dbg = api_util.debug_info("attrs linearize", f, primals, {})
f_, out_tree = flatten_fun_nokwargs(
_set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree)
primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
f_, *attr_primals, *primals_flat)
f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
@ -206,7 +210,9 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []):
attr_primals = [jax_getattr(o, a) for o, a in attrs]
primals_flat, in_tree = tree_flatten(primals)
tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
dbg = api_util.debug_info("attrs vjp", f, primals, {})
f_, out_tree = flatten_fun_nokwargs(
_set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree)
primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
f_, *attr_primals, *primals_flat)
attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval()

View File

@ -599,6 +599,8 @@ class GraphSerializationImpl(SerializationImpl):
name_stack = util.wrap_name(fun_name, "jax2tf")
self.name_stack = name_stack
self.args_flat_tf = args_flat_tf
self.debug = api_util.debug_info("jax2tf", fun_jax,
args_specs, kwargs_specs)
def before_conversion(self):
prev_enable_xla = _thread_local_state.enable_xla
@ -623,7 +625,10 @@ class GraphSerializationImpl(SerializationImpl):
dim_values, _ = _interpret_fun_jax(
partial(shape_poly.compute_dim_vars_from_arg_shapes,
self.args_avals_flat, args_kwargs_tree=self.in_tree),
self.args_flat_tf, self.args_avals_flat, self.name_stack)
self.args_flat_tf, self.args_avals_flat, self.name_stack,
debug_info=api_util.debug_info("jax2tf dim_vars",
shape_poly.compute_dim_vars_from_arg_shapes,
self.args_specs, self.kwargs_specs))
_thread_local_state.shape_env = zip(dim_vars, dim_values)
@ -639,7 +644,8 @@ class GraphSerializationImpl(SerializationImpl):
fun_flat_jax,
args_flat_tf, self.args_avals_flat,
self.name_stack,
fresh_constant_cache=True)
fresh_constant_cache=True,
debug_info=self.debug)
return outs_tf, self.outs_avals, out_tree_thunk()
def get_vjp_fun(self) -> tuple[Callable,
@ -849,10 +855,12 @@ def _interpret_fun_jax(
fun_jax: Callable,
args_tf: Sequence[TfVal],
args_avals: Sequence[core.ShapedArray],
extra_name_stack: str | None,
extra_name_stack: str | None, *,
fresh_constant_cache: bool = False,
debug_info: core.DebugInfo,
) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]:
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), args_avals)
subtrace_fun = _interpret_subtrace(
lu.wrap_init(fun_jax, debug_info=debug_info), args_avals)
with _extended_name_stack(extra_name_stack):
out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \
_call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
@ -1033,7 +1041,9 @@ def _convert_jax_impl(impl_jax: Callable, *,
results_tf, _ = _interpret_fun_jax(
impl_multiple_results_jax, args_tf, _in_avals,
extra_name_stack)
extra_name_stack,
debug_info=api_util.debug_info("jax2tf", impl_jax,
args_tf, kwargs))
return results_tf if multiple_results else results_tf[0]
return wrapped_tf
@ -1066,7 +1076,8 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
"""
outs_tf, _ = _interpret_fun_jax(core.jaxpr_as_fun(jaxpr),
args_tf, jaxpr.in_avals, extra_name_stack,
fresh_constant_cache=fresh_constant_cache)
fresh_constant_cache=fresh_constant_cache,
debug_info=jaxpr.jaxpr.debug_info)
return outs_tf
@ -1197,7 +1208,9 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfV
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
shape_values_tf, _ = _interpret_fun_jax(
partial(core.evaluate_shape, shape, dim_vars),
dim_values, [core.dim_value_aval()] * len(dim_values), "") # type: ignore
dim_values, [core.dim_value_aval()] * len(dim_values), "", # type: ignore
debug_info=api_util.debug_info("jax2tf evaluate_shape", core.evaluate_shape,
(0, 0, *dim_values), {}))
# Keep only the non-constant dimensions
return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf # type: ignore
for d, d_tf in zip(shape, shape_values_tf))
@ -3431,10 +3444,10 @@ def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params):
tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve
def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable,
jvp_jaxpr_fun: Callable,
num_consts: int) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
del jvp_jaxpr_thunk, num_consts
del jvp_jaxpr_fun, num_consts
return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_jvp",
fresh_constant_cache=False)

View File

@ -38,6 +38,7 @@ from jax import lax
import jax.numpy as jnp
from jax import random
from jax import tree_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
@ -442,7 +443,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
partial(shape_poly.compute_dim_vars_from_arg_shapes,
avals,
args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
args_tf, avals, "")
args_tf, avals, "",
debug_info=api_util.debug_info("jax2tf dim_vars",
shape_poly.compute_dim_vars_from_arg_shapes,
avals, {}))
if expected_shapes is not None:
expected_avals = tree_util.tree_map(
lambda shape_str: core.ShapedArray(

View File

@ -1387,7 +1387,7 @@ def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs):
@register_check(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
return _check_rep(mesh, call_jaxpr.jaxpr, in_rep)
@ -1578,11 +1578,13 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
def _shard_map_linearize(trace, shard_map_p, f, tracers, mesh, in_names,
def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto):
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in,
f.debug_info)
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
all_names = _all_mesh_names_except_spmd(mesh, trace)
@ -1780,23 +1782,24 @@ def _partial_eval_jaxpr_custom_rule(
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
_partial_eval_jaxpr_custom_rule
def _add_reshapes(which, jaxpr_known, jaxpr_staged):
def _add_reshapes(which: Sequence[bool],
jaxpr_known: core.Jaxpr,
jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]:
# add singleton axes to residuals which are from jaxpr_known and are scalars
which_ = [w and not v.aval.shape # pytype: disable=attribute-error
for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
if not any(which_): return jaxpr_known, jaxpr_staged
assert not jaxpr_known.constvars and not jaxpr_staged.constvars
@lu.wrap_init
def known(*args):
out = core.eval_jaxpr(jaxpr_known, (), *args)
out_known, res = split_list(out, [len(out) - sum(which)])
res = [_add_singleton(x) if not x.shape else x for x in res]
return [*out_known, *res]
avals_in = [v.aval for v in jaxpr_known.invars]
jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(known, avals_in)
jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in)
@lu.wrap_init
def staged(*args):
res_, ins = split_list(args, [len(which)])
res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
@ -1804,7 +1807,8 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged):
res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in)
return jaxpr_known, jaxpr_staged
@ -2070,7 +2074,8 @@ def _replication_rewrite_match(
in_rep: Sequence[set[AxisName]],
out_rep_dst: Sequence[set[AxisName]],
) -> core.ClosedJaxpr:
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
f = _match_rep(f, mesh, out_rep, out_rep_dst)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)

View File

@ -1042,8 +1042,14 @@ def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
@bcoo_dot_general_sampled_p.def_abstract_eval
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B)
sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result)
dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval",
lax.dot_general, (A, B), dict(dimension_numbers=dimension_numbers))
dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B,
debug_info=dbg)
dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval",
_bcoo_extract, (indices, dense_result), {})
sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result,
debug_info=dbg)
return sparse_result
def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):

View File

@ -55,6 +55,7 @@ import numpy as np
import jax
from jax import lax
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src.custom_derivatives import lift_jvp
@ -365,12 +366,15 @@ def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]):
spenv = SparsifyEnv(out_bufs)
return spvalues_to_arrays(spenv, out_spvalues())
def _sparsify_with_tracer(fun):
def _sparsify_with_tracer(fun: Callable):
"""Implementation of sparsify() using tracers."""
@functools.wraps(fun)
def _wrapped(*args):
args_flat, in_tree = tree_flatten(args, is_leaf=_is_sparse_obj)
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
wrapped_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=api_util.debug_info("sparsify", fun, args, {})),
in_tree)
out = sparsify_fun(wrapped_fun, args_flat)
return tree_unflatten(out_tree(), out)
return _wrapped
@ -439,7 +443,12 @@ def sparsify_raw(f):
) -> tuple[Sequence[SparsifyValue], pytree.PyTreeDef]:
spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
wrapped_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(
f, params,
debug_info=api_util.debug_info("sparsify", f,
spvalues_to_arrays(spenv, spvalues), {})),
in_tree)
jaxpr, out_avals_flat, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
if len(out_avals_flat) != len(result):
@ -716,14 +725,14 @@ def _gather_sparse_rule(spenv, *args, dimension_numbers, slice_sizes, unique_ind
sparse_rules_bcoo[lax.gather_p] = _gather_sparse_rule
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
def _sparsify_jaxpr(spenv: SparsifyEnv,
jaxpr: core.ClosedJaxpr, *spvalues):
# TODO(jakevdp): currently this approach discards all information about
# shared data & indices when generating the sparsified jaxpr. The
# current approach produces valid sparsified while loops, but they
# don't work in corner cases (see associated TODO in sparsify_test.py)
out_tree: pytree.PyTreeDef | None = None
@lu.wrap_init
def wrapped(*args_flat):
# TODO(frostig,jakevdp): This closes over `spenv`, which can bring
# in buffers from the "outer scope" as constants. Is this a
@ -740,7 +749,8 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
args = spvalues_to_arrays(spenv, spvalues)
args_flat, in_tree = tree_flatten(args)
avals_flat = [core.get_aval(arg) for arg in args_flat]
sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(wrapped, debug_info=jaxpr.jaxpr.debug_info), avals_flat)
sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
assert out_tree is not None
return sp_jaxpr, out_tree
@ -866,17 +876,18 @@ sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule
def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
call_jaxpr: core.ClosedJaxpr = params.pop('call_jaxpr')
jvp_jaxpr_thunk = params.pop('jvp_jaxpr_thunk')
num_consts = params.pop('num_consts')
jvp_jaxpr_fun: lu.WrappedFun = params.pop('jvp_jaxpr_fun')
num_consts: int = params.pop('num_consts')
sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, call_jaxpr, *spvalues)
@lu.wrap_init
def fun(*arrs):
sparrs = arrays_to_spvalues(spenv, arrs)
out = eval_sparse(call_jaxpr.jaxpr, call_jaxpr.consts, sparrs, spenv)
return spvalues_to_arrays(spenv, out)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_fun)
invals = spvalues_to_arrays(spenv, spvalues)
outvals = jax.custom_derivatives.custom_jvp_call_p.bind(fun, jvp, *invals, **params)
outvals = jax.custom_derivatives.custom_jvp_call_p.bind(
lu.wrap_init(fun, debug_info=call_jaxpr.jaxpr.debug_info),
jvp, *invals, **params)
return arrays_to_spvalues(spenv, outvals)
sparse_rules_bcoo[jax.custom_derivatives.custom_jvp_call_p] = _custom_jvp_sparse_rule

View File

@ -26,9 +26,8 @@ import jax
from jax import lax
from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs
from jax.api_util import flatten_fun_nokwargs, debug_info
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
@ -48,14 +47,16 @@ def call(f, *args):
@util.curry
def core_call(f, *args):
args, in_tree = jax.tree.flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
dbg = debug_info("core_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
out = core.call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
@util.curry
def core_closed_call(f, *args):
args, in_tree = jax.tree.flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
dbg = debug_info("core_closed_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
out = core.closed_call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
@ -362,7 +363,10 @@ class CoreTest(jtu.JaxTestCase):
aval = core.ShapedArray((), jnp.dtype('int32'))
pval = pe.PartialVal.unknown(aval)
jaxpr, _, _ = pe.trace_to_jaxpr_nounits(lu.wrap_init(f), [pval], False)
jaxpr, _, _ = pe.trace_to_jaxpr_nounits(
lu.wrap_init(f,
debug_info=debug_info("test", f, (0,), {})),
[pval], False)
dropvar, b = jaxpr.eqns[0].outvars
self.assertEqual(dropvar.aval, aval)
@ -548,12 +552,13 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
return x, y
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 3)
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@ -569,7 +574,6 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
@jax.jit
def g(x, y, z, w):
@ -577,7 +581,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
return g(x, y, x, y)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (0, 1), {})),
[n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@ -605,7 +611,6 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
@jax.jit
def g(_, x, y, z, w):
@ -613,7 +618,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
return g(x.shape[0], x, y, x, y)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])
# { lambda ; a:i32[] b:f32[a] c:f32[a]. let
# d:f32[a] e:f32[a] = xla_call[
@ -641,7 +648,6 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
z = lax.mul(x, y)
w = lax.sin(z)
@ -649,7 +655,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (u,)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertLen(jaxpr.eqns, 3)
@ -667,14 +675,15 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(a, b):
@jax.jit
def g(x): return x
return g(a),
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, m, a, b], keep_inputs=[False, False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, m, a, b], keep_inputs=[False, False, True, True])
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
# e:f32[a] = xla_call[
# call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) }

View File

@ -901,7 +901,7 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): what are these flat_index components?
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][<flat index 0>][0][<flat index 0>][0][0]",
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]",
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
],
@ -1091,8 +1091,7 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=",
"traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=",
# TODO(necula): flat_index?
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[<flat index 0>][0][0],[<flat index 0>][0][1]",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0][0][0],[0][0][1]",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[

View File

@ -29,6 +29,7 @@ import numpy as np
import jax
import jax.ad_checkpoint
from jax import api_util
from jax import lax
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
@ -1329,7 +1330,11 @@ class ShardMapTest(jtu.JaxTestCase):
def test_rewrite_process_call(self):
def f(x):
return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x
return core.call_p.bind(
lu.wrap_init(lambda x: [2. * x],
debug_info=api_util.debug_info("test", lambda x: [2. * x],
(x,), {})),
x)[0] * x
mesh = jtu.create_mesh((4,), ('x',))
g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
@ -1345,7 +1350,10 @@ class ShardMapTest(jtu.JaxTestCase):
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
def f(x):
return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x
return core.call_p.bind(
lu.wrap_init(lambda: [2. * x],
debug_info=api_util.debug_info("test", lambda: [2. * x],
(), {})))[0] * x
x = jnp.arange(4.)
y = f(x)