[better_errors] Continue adding debug info to Jaxprs (step 7)

This continues a series, starting with #26078 and #26313, that adds debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
This commit is contained in:
George Necula 2025-02-08 15:19:46 +02:00
parent 8401d9b752
commit 817b3e5757
20 changed files with 172 additions and 87 deletions

View File

@ -610,10 +610,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
except (ValueError, TypeError): except (ValueError, TypeError):
return None return None
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable): def save_wrapped_fun_sourceinfo(wrapper: Callable,
wrapped: Callable | core.DebugInfo | None) -> None:
# Prefer this to functools.wraps because it does not create a reference to # Prefer this to functools.wraps because it does not create a reference to
# the wrapped function. # the wrapped function.
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped)) if isinstance(wrapped, core.DebugInfo):
func_src_info = wrapped.func_src_info
elif callable(wrapped):
func_src_info = fun_sourceinfo(wrapped)
else:
return
setattr(wrapper, "__fun_sourceinfo__", func_src_info)
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)") _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")

View File

@ -1823,7 +1823,7 @@ def _move_mutable_consts(
invars = (*jaxpr.invars, *mutvars) invars = (*jaxpr.invars, *mutvars)
effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns) effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns, jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
effects, None) effects, closed_jaxpr.jaxpr.debug_info)
return core.ClosedJaxpr(jaxpr, consts), in_mut return core.ClosedJaxpr(jaxpr, consts), in_mut
@weakref_lru_cache @weakref_lru_cache

View File

@ -1062,7 +1062,11 @@ def core_map(
""" """
def wrapped(f): def wrapped(f):
flat_args, in_tree = tree_util.tree_flatten(((), {})) flat_args, in_tree = tree_util.tree_flatten(((), {}))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f,
debug_info=api_util.debug_info("pallas_core_map", f,
(), {})),
in_tree)
with jax_core.extend_axis_env_nd(mesh.shape.items()): with jax_core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args)
out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh,

View File

@ -91,7 +91,10 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate:
""" """
flattened_args, treedef = jax.tree.flatten(args) flattened_args, treedef = jax.tree.flatten(args)
partial_fun = functools.partial(fun, **kwargs) partial_fun = functools.partial(fun, **kwargs)
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun), wrapped_fun, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(partial_fun,
debug_info=api_util.debug_info("cost_estimate", fun,
args, kwargs)),
treedef) treedef)
avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)

View File

@ -476,7 +476,9 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args, return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
compiler_params=compiler_params) compiler_params=compiler_params)
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
new_jaxpr = _to_jaxpr(lu.wrap_init(f), in_avals) new_jaxpr = _to_jaxpr(lu.wrap_init(f,
debug_info=pjit_jaxpr.jaxpr.debug_info),
in_avals)
out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
elif prim is primitives.run_scoped_p: elif prim is primitives.run_scoped_p:

View File

@ -23,6 +23,7 @@ import string
from typing import Any, Hashable from typing import Any, Hashable
import jax import jax
from jax import api_util
from jax import lax from jax import lax
from jax import tree_util from jax import tree_util
from jax._src import ad_util from jax._src import ad_util
@ -845,7 +846,10 @@ def lower_jaxpr_to_func(
def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable: def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
def f_lowered(ctx: LoweringRuleContext, *args, **params): def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params) wrapped_fun = lu.wrap_init(
f, params,
debug_info=api_util.debug_info("mosaic lower_fun", f,
args, params))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
if consts: if consts:
raise NotImplementedError raise NotImplementedError

View File

@ -24,6 +24,7 @@ import math
from typing import Any from typing import Any
import jax import jax
from jax import api_util
from jax import lax from jax import lax
from jax._src import core from jax._src import core
from jax._src import linear_util as lu from jax._src import linear_util as lu
@ -94,7 +95,12 @@ def _uses_arguments(
index_map: Callable[..., Any], num_args: int index_map: Callable[..., Any], num_args: int
) -> Sequence[bool]: ) -> Sequence[bool]:
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args lu.wrap_init(
index_map,
debug_info=api_util.debug_info("pallas index_map",
index_map,
(0,) * num_args, {})),
(core.ShapedArray((), jnp.int32),) * num_args
) )
_, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars)) _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
return used_inputs return used_inputs

View File

@ -1001,7 +1001,10 @@ def pallas_call_checkify_oob_grid(error: checkify.Error,
) )
flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),)) flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),))
wrapped_loop, _ = api_util.flatten_fun_nokwargs( wrapped_loop, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(f), jaxpr_in_tree) lu.wrap_init(f,
debug_info=api_util.debug_info("checkify oob_grid_access",
f, (0,), {})),
jaxpr_in_tree)
with pallas_core.tracing_grid_env(grid_mapping.grid, ()): with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
avals_in = map(jax_core.get_aval, flat_args) avals_in = map(jax_core.get_aval, flat_args)
traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic( traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic(
@ -1426,7 +1429,7 @@ def _pallas_call_state_discharge_rule(
) )
) )
new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic( new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_rewritten_body), lu.wrap_init(_rewritten_body, debug_info=jaxpr.debug_info),
[ [
*index_map_avals, *index_map_avals,
*ref_avals, *ref_avals,

View File

@ -848,7 +848,11 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
types in addition to :class:`jax.experimental.pallas.MemoryRef`. types in addition to :class:`jax.experimental.pallas.MemoryRef`.
""" """
flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f,
debug_info=api_util.debug_info("pallas run_scoped",
f, types, kw_types)),
in_tree)
# We allow ref avals to be transformed references. # We allow ref avals to be transformed references.
ref_avals = [t.get_ref_aval() for t in flat_types] ref_avals = [t.get_ref_aval() for t in flat_types]
avals = [ avals = [

View File

@ -415,7 +415,10 @@ def lower_fun(
fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
def f_lowered(ctx: LoweringRuleContext, *args, **params): def f_lowered(ctx: LoweringRuleContext, *args, **params):
wrapped_fun = lu.wrap_init(fn, params) wrapped_fun = lu.wrap_init(
fn, params,
debug_info=api_util.debug_info("pallas triton lower_fun", fun,
args, params))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
@ -539,7 +542,11 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes):
] ]
in_tree = tree_util.tree_structure((args, args)) in_tree = tree_util.tree_structure((args, args))
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree lu.wrap_init(
body,
debug_info=api_util.debug_info("pallas triton associative_scan",
body, (args, args), {})),
in_tree
) )
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, in_avals flat_fun, in_avals
@ -2185,7 +2192,11 @@ def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
mapped_avals = [jax_core.ShapedArray((), aval.dtype) for aval in ctx.avals_in] mapped_avals = [jax_core.ShapedArray((), aval.dtype) for aval in ctx.avals_in]
in_tree = tree_util.tree_structure((a, a)) in_tree = tree_util.tree_structure((a, a))
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree lu.wrap_init(
body,
debug_info=api_util.debug_info("pallas triton reduction",
body, (a, a), {})),
in_tree
) )
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, [*mapped_avals, *mapped_avals] flat_fun, [*mapped_avals, *mapped_avals]

View File

@ -460,12 +460,13 @@ def _addupdate_discharge(x, val, idx, tree):
return _prepend_scatter(x, indexer, val, add=True) return _prepend_scatter(x, indexer, val, add=True)
@weakref_lru_cache @weakref_lru_cache
def _cached_closed_jaxpr_discharge(closed_jaxpr): def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr):
jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
num_outs = len(jaxpr.outvars) num_outs = len(jaxpr.outvars)
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr)) fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr),
debug_info=discharged_jaxpr.debug_info)
return discharged_closed_jaxpr, num_outs, fun return discharged_closed_jaxpr, num_outs, fun
@register_discharge_rule(core.closed_call_p) @register_discharge_rule(core.closed_call_p)
@ -598,7 +599,6 @@ def _convert_outputs_to_writes(
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars." assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
in_avals = [v.aval for v in jaxpr.invars] in_avals = [v.aval for v in jaxpr.invars]
@lu.wrap_init
def eval_jaxpr(*refs): def eval_jaxpr(*refs):
# We split the refs into the original input refs and the dummy residual # We split the refs into the original input refs and the dummy residual
# refs. # refs.
@ -610,14 +610,15 @@ def _convert_outputs_to_writes(
res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef) res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
else v.aval for v in jaxpr.outvars] else v.aval for v in jaxpr.outvars]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals]) lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info),
[*in_avals, *res_ref_avals])
assert not consts assert not consts
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]
def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr: def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
assert not jaxpr.constvars, "Jaxpr should not have constvars" assert not jaxpr.constvars, "Jaxpr should not have constvars"
@lu.wrap_init
def eval_jaxpr(*refs): def eval_jaxpr(*refs):
residual_refs, orig_refs = split_list(refs, [num_res]) residual_refs, orig_refs = split_list(refs, [num_res])
residual_vals = [r[...] for r in residual_refs] residual_vals = [r[...] for r in residual_refs]
@ -629,7 +630,9 @@ def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
aval for aval in res_val_avals] aval for aval in res_val_avals]
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*res_ref_avals, *orig_ref_avals]) lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info),
[*res_ref_avals, *orig_ref_avals])
return jaxpr return jaxpr
def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
@ -845,11 +848,12 @@ def _run_state_partial_eval_custom(
*[v.aval for v in res_staged_invars], **staged_params) *[v.aval for v in res_staged_invars], **staged_params)
_, staged_outvars = partition_list(in_unknowns, eqn.outvars) _, staged_outvars = partition_list(in_unknowns, eqn.outvars)
if num_res: if num_res:
@lu.wrap_init
def staged(*args): def staged(*args):
out = run_state_p.bind(*args, **staged_params) out = run_state_p.bind(*args, **staged_params)
return out[num_res:] return out[num_res:]
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged, staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info),
[v.aval for v in res_staged_invars]) [v.aval for v in res_staged_invars])
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars, eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
staged_outvars, staged_outvars,
@ -918,7 +922,9 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool],
ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ()) ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ())
return [] return []
jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic( jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars]) lu.wrap_init(trans,
debug_info=jaxpr.debug_info),
[v.aval for v in jaxpr.invars])
return jaxpr_trans, consts return jaxpr_trans, consts
def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,

View File

@ -19,6 +19,7 @@ from jax._src.core import (
from jax._src.api_util import ( from jax._src.api_util import (
argnums_partial as argnums_partial, argnums_partial as argnums_partial,
debug_info as debug_info,
donation_vector as donation_vector, donation_vector as donation_vector,
flatten_axes as flatten_axes, flatten_axes as flatten_axes,
flatten_fun as flatten_fun, flatten_fun as flatten_fun,

View File

@ -60,6 +60,7 @@ from functools import partial
import numpy as np import numpy as np
from jax import lax from jax import lax
from jax import api_util
import jax.numpy as jnp import jax.numpy as jnp
from jax.experimental import pjit from jax.experimental import pjit
from jax.tree_util import (register_pytree_node, tree_structure, from jax.tree_util import (register_pytree_node, tree_structure,
@ -147,7 +148,9 @@ def jet(fun, primals, series):
store.store(tree) store.store(tree)
return ans return ans
f, out_tree = flatten_fun_output(lu.wrap_init(fun)) f, out_tree = flatten_fun_output(
lu.wrap_init(fun,
debug_info=api_util.debug_info("jet", fun, primals, {})))
out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series) out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms) return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
@ -723,7 +726,7 @@ jet_rules[lax.scatter_add_p] = _scatter_add_rule
def _jet_jaxpr( def _jet_jaxpr(
jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
) -> tuple[core.ClosedJaxpr, Any]: ) -> tuple[core.ClosedJaxpr, Any]:
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)
f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def) f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def)
jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic( jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic(
f_jet, primals_and_series_avals) f_jet, primals_and_series_avals)

View File

@ -397,7 +397,10 @@ def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature: def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature:
args_flat, in_tree = tree_util.tree_flatten(args) args_flat, in_tree = tree_util.tree_flatten(args)
in_avals_flat = [core.get_aval(arg) for arg in args_flat] in_avals_flat = [core.get_aval(arg) for arg in args_flat]
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) wrapped_fun, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=api_util.debug_info("key_reuse", fun, args, {})),
in_tree)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
return jaxpr_type_signature(jaxpr) return jaxpr_type_signature(jaxpr)

View File

@ -28,8 +28,10 @@ Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
from functools import partial from functools import partial
import operator as op import operator as op
from typing import Callable
import jax import jax
from jax import api_util
import jax.numpy as jnp import jax.numpy as jnp
from jax._src import core from jax._src import core
from jax import custom_derivatives from jax import custom_derivatives
@ -44,8 +46,9 @@ map = safe_map
zip = safe_zip zip = safe_zip
def ravel_first_arg(f, unravel): def ravel_first_arg(f: Callable, unravel, debug_info: core.DebugInfo):
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped return ravel_first_arg_(lu.wrap_init(f, debug_info=debug_info),
unravel).call_wrapped
@lu.transformation2 @lu.transformation2
def ravel_first_arg_(f, unravel, y_flat, *args): def ravel_first_arg_(f, unravel, y_flat, *args):
@ -179,9 +182,10 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf, hmax=jn
return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts) return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4)) @partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
def _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args): def _odeint_wrapper(func: Callable, rtol, atol, mxstep, hmax, y0, ts, *args):
y0, unravel = ravel_pytree(y0) y0, unravel = ravel_pytree(y0)
func = ravel_first_arg(func, unravel) debug = api_util.debug_info("odeint", func, args, {})
func = ravel_first_arg(func, unravel, debug)
out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args) out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
return jax.vmap(unravel)(out) return jax.vmap(unravel)(out)

View File

@ -17,6 +17,7 @@ import unittest
from absl.testing import absltest from absl.testing import absltest
import jax import jax
from jax import api_util
import jax.numpy as jnp import jax.numpy as jnp
from jax import lax from jax import lax
from jax.experimental import pjit from jax.experimental import pjit
@ -178,12 +179,13 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
def test_core_call_primitive_inherits_effects(self): def test_core_call_primitive_inherits_effects(self):
def f(x): def f(x):
@lu.wrap_init
def f_(x): def f_(x):
effect_p.bind(effect=foo_effect) effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect) effect_p.bind(effect=bar_effect)
return [x] return [x]
return core.call(f_, x)[0] dbg = api_util.debug_info("test", f_, (2.,), {})
return core.call(
lu.wrap_init(f_, debug_info=dbg), x)[0]
jaxpr = jax.make_jaxpr(f)(2.) jaxpr = jax.make_jaxpr(f)(2.)
self.assertIn(foo_effect, jaxpr.jaxpr.effects) self.assertIn(foo_effect, jaxpr.jaxpr.effects)
self.assertIn(bar_effect, jaxpr.jaxpr.effects) self.assertIn(bar_effect, jaxpr.jaxpr.effects)

View File

@ -15,6 +15,7 @@ import functools
from absl.testing import absltest from absl.testing import absltest
import jax import jax
from jax import api_util
import jax.numpy as jnp import jax.numpy as jnp
from jax._src import core from jax._src import core
from jax import lax from jax import lax
@ -85,11 +86,12 @@ class NameStackTest(jtu.JaxTestCase):
def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self): def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
@jax.named_scope('foo') @jax.named_scope('foo')
def f(x): def f(x):
@lu.wrap_init
@jax.named_scope('bar') @jax.named_scope('bar')
def _f(x): def _f(x):
return [x + 1] return [x + 1]
return core.call(_f, x)[0] return core.call(lu.wrap_init(
_f,
debug_info=api_util.debug_info("test", _f, (0,), {})), x)[0]
jaxpr = jax.make_jaxpr(f)(2).jaxpr jaxpr = jax.make_jaxpr(f)(2).jaxpr
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')

View File

@ -19,12 +19,13 @@ import functools
import itertools import itertools
import math import math
import sys import sys
from typing import Any from typing import Any, Callable
import unittest import unittest
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import jax import jax
from jax import api_util
from jax import lax from jax import lax
from jax import random from jax import random
from jax._src import dtypes from jax._src import dtypes
@ -61,6 +62,11 @@ jtu.setup_hypothesis(max_examples=50)
intx = dtypes.canonicalize_dtype(jnp.int64) intx = dtypes.canonicalize_dtype(jnp.int64)
floatx = dtypes.canonicalize_dtype(jnp.float64) floatx = dtypes.canonicalize_dtype(jnp.float64)
def wrap_init(f: Callable, nr_args: int):
# Test helper: wraps `f` via lu.wrap_init and attaches debugging info to it.
# The debug info is built by api_util.debug_info from the label "state_test",
# the function `f`, a tuple of `nr_args` zeros as the positional arguments,
# and an empty dict as the keyword arguments.
# The zeros are presumably placeholders that only convey the argument
# count -- TODO confirm against api_util.debug_info.
# NOTE(review): the label reads "state_test" while the nearby tests belong to
# PallasPrimitivesTest -- looks copied from another test file; confirm intended.
return lu.wrap_init(
f,
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
def is_power_of_two(n: int) -> bool: def is_power_of_two(n: int) -> bool:
return (n > 0) and (n & (n - 1) == 0) return (n > 0) and (n & (n - 1) == 0)
@ -2284,7 +2290,7 @@ class PallasPrimitivesTest(PallasBaseTest):
x = pl.load(x_ref, expr()) x = pl.load(x_ref, expr())
return [x] return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
self.assertIn(expected, jaxpr.pretty_print(use_color=False)) self.assertIn(expected, jaxpr.pretty_print(use_color=False))
@parameterized.parameters(*[ @parameterized.parameters(*[
@ -2299,7 +2305,7 @@ class PallasPrimitivesTest(PallasBaseTest):
pl.store(x_ref, expr(), pl.load(x_ref, expr())) pl.store(x_ref, expr(), pl.load(x_ref, expr()))
return [] return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
self.assertIn(expected, jaxpr.pretty_print(use_color=False)) self.assertIn(expected, jaxpr.pretty_print(use_color=False))
@parameterized.parameters(*[ @parameterized.parameters(*[
@ -2319,7 +2325,7 @@ class PallasPrimitivesTest(PallasBaseTest):
x = pl.swap(x_ref, expr(), pl.load(x_ref, expr())) x = pl.swap(x_ref, expr(), pl.load(x_ref, expr()))
return [x] return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
self.assertIn(expected, jaxpr.pretty_print(use_color=False)) self.assertIn(expected, jaxpr.pretty_print(use_color=False))

View File

@ -23,6 +23,7 @@ from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import jax import jax
from jax import api_util
from jax import random from jax import random
from jax import lax from jax import lax
from jax._src import core from jax._src import core
@ -54,6 +55,12 @@ from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
config.parse_flags_with_absl() config.parse_flags_with_absl()
jtu.setup_hypothesis() jtu.setup_hypothesis()
def wrap_init(f: Callable, nr_args: int):
# Test helper: wraps `f` via lu.wrap_init and attaches debugging info to it.
# The debug info is built by api_util.debug_info from the label "state_test",
# the function `f`, a tuple of `nr_args` zeros as the positional arguments,
# and an empty dict as the keyword arguments.
# The zeros are presumably placeholders that only convey the argument
# count -- TODO confirm against api_util.debug_info.
return lu.wrap_init(
f,
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
class StatePrimitivesTest(jtu.JaxTestCase): class StatePrimitivesTest(jtu.JaxTestCase):
@ -62,7 +69,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def f(x_ref): def f(x_ref):
return [ref_get(x_ref, ())] return [ref_get(x_ref, ())]
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval]) pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval])
@parameterized.named_parameters( @parameterized.named_parameters(
dict(testcase_name="trivial_get", ref_shape=(1, 2), dict(testcase_name="trivial_get", ref_shape=(1, 2),
@ -121,10 +128,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
return [out] return [out]
if should_error: if should_error:
with self.assertRaises(Exception): with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval]) pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval])
else: else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval]) wrap_init(f, 1), [ref_aval])
self.assertSetEqual(jaxpr.effects, self.assertSetEqual(jaxpr.effects,
{ReadEffect(len(jaxpr.constvars))}) {ReadEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 1) self.assertLen(out_avals, 1)
@ -139,7 +146,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def f(x_ref, val): def f(x_ref, val):
return [ref_swap(x_ref, (), val)] return [ref_swap(x_ref, (), val)]
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
@parameterized.named_parameters( @parameterized.named_parameters(
dict(testcase_name="invalid_val_shape", ref_shape=(1, 2), dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
@ -218,10 +225,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
return [out] return [out]
if should_error: if should_error:
with self.assertRaises(Exception): with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
else: else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval]) wrap_init(f, 2), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects, self.assertSetEqual(jaxpr.effects,
{WriteEffect(len(jaxpr.constvars))}) {WriteEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 1) self.assertLen(out_avals, 1)
@ -295,10 +302,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
return [] return []
if should_error: if should_error:
with self.assertRaises(Exception): with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
else: else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval]) wrap_init(f, 2), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects, self.assertSetEqual(jaxpr.effects,
{AccumEffect(len(jaxpr.constvars))}) {AccumEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 0) self.assertLen(out_avals, 0)
@ -309,7 +316,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def f(x_ref, val): def f(x_ref, val):
return [ref_addupdate(x_ref, (), val)] return [ref_addupdate(x_ref, (), val)]
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
def test_can_represent_get_and_swap_in_jaxprs(self): def test_can_represent_get_and_swap_in_jaxprs(self):
@ -318,7 +325,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x[()] = jnp.int32(2) x[()] = jnp.int32(2)
return (x[()],) return (x[()],)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0) self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, swap_p) self.assertEqual(jaxpr.eqns[0].primitive, swap_p)
@ -331,7 +338,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_addupdate(x, (), jnp.int32(1)) ref_addupdate(x, (), jnp.int32(1))
return (x[()],) return (x[()],)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0) self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p) self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
@ -341,14 +348,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x = x_ref[()] x = x_ref[()]
return [x] return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False)) self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
def body(x_ref): def body(x_ref):
x = x_ref[:, 0] x = x_ref[:, 0]
return [x] return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)]) wrap_init(body, 1), [shaped_array_ref((1, 2), jnp.int32)])
self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False)) self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))
def test_set_custom_pretty_printing_rule(self): def test_set_custom_pretty_printing_rule(self):
@ -356,14 +363,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x_ref[()] = jnp.int32(2) x_ref[()] = jnp.int32(2)
return [] return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val): def body(x_ref, val):
x_ref[:, 0] = val x_ref[:, 0] = val
return [] return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)]) core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False)) self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))
@ -372,14 +379,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x = ref_swap(x_ref, (), jnp.int32(2)) x = ref_swap(x_ref, (), jnp.int32(2))
return [x] return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val): def body(x_ref, val):
x = ref_swap(x_ref, (slice(None), 0), val) x = ref_swap(x_ref, (slice(None), 0), val)
return [x] return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)]) core.ShapedArray((1,), jnp.int32)])
self.assertIn("c:i32[1], a[:,0] <- a[:,0], b", self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
jaxpr.pretty_print(use_color=False)) jaxpr.pretty_print(use_color=False))
@ -389,7 +396,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_addupdate(x_ref, (), jnp.int32(2)) ref_addupdate(x_ref, (), jnp.int32(2))
return [] return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False)) self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
@ -397,7 +404,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_addupdate(x_ref, (slice(None), 0), val) ref_addupdate(x_ref, (slice(None), 0), val)
return [] return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)]) core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False)) self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))
@ -413,7 +420,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')), in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))] shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p) self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p) self.assertEqual(jaxpr.eqns[1].primitive, get_p)
@ -429,7 +436,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')), in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))] shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p) self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p) self.assertEqual(jaxpr.eqns[1].primitive, get_p)
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p) self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
@ -449,7 +456,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')), in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))] shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p) self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p) self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[2].primitive, get_p) self.assertEqual(jaxpr.eqns[2].primitive, get_p)
@ -521,12 +528,12 @@ class StatePrimitivesTest(jtu.JaxTestCase):
# discharge-of-vmap # discharge-of-vmap
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals]) wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
# vmap-of-discharge # vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals]) wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, *idx_bdims), in_axes=(ref_bdim, *idx_bdims),
@ -544,7 +551,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, ()) a = ref_get(a_ref, ())
return [a + 1] return [a + 1]
in_avals = [shaped_array_ref((), jnp.dtype('float32'))] in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
# Discharging should just turn this into a jaxpr that just adds 1. # Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
@ -560,7 +567,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, (0, 1)) a = ref_get(a_ref, (0, 1))
return [a + 1] return [a + 1]
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
# Discharging should just turn this into a jaxpr that just adds 1. # Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@ -580,7 +587,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return [a + 1] return [a + 1]
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), in_avals) wrap_init(f, 1), in_avals)
discharged_jaxpr, discharged_consts = discharge_state( discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts) stateful_jaxpr, consts)
inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3)) inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
@ -594,7 +601,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return [] return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')), in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))] core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
# Discharging should just turn this into a jaxpr that ignores the first # Discharging should just turn this into a jaxpr that ignores the first
# value and returns second value plus 1. # value and returns second value plus 1.
@ -611,7 +618,7 @@ class StateDischargeTest(jtu.JaxTestCase):
ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32'))) ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
return [] return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
# Discharging should just turn this into a jaxpr that just adds 1. # Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@ -631,7 +638,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32') a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
return [] return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
discharged_jaxpr, discharged_consts = discharge_state( discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts) stateful_jaxpr, consts)
@ -648,7 +655,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return [a + 1] return [a + 1]
in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)] in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), in_avals) wrap_init(f, 1), in_avals)
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1) self.assertLen(discharged_jaxpr.invars, 1)
@ -665,7 +672,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return [] return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')), in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))] core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
in_avals) in_avals)
# Discharging should just turn this into a jaxpr that adds the first value, # Discharging should just turn this into a jaxpr that adds the first value,
# second value, and 1. # second value, and 1.
@ -683,7 +690,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones(2, dtype=jnp.dtype('float32'))) jnp.ones(2, dtype=jnp.dtype('float32')))
return [] return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1) self.assertLen(discharged_jaxpr.invars, 1)
@ -704,7 +711,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones((2, 3), 'float32')) jnp.ones((2, 3), 'float32'))
return [] return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
discharged_jaxpr, discharged_consts = discharge_state( discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts) stateful_jaxpr, consts)
@ -718,7 +725,7 @@ class StateDischargeTest(jtu.JaxTestCase):
b = a + 1 b = a + 1
return [a, b] return [a, b]
in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))] in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals) in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1) self.assertLen(discharged_jaxpr.invars, 1)
@ -738,7 +745,7 @@ class StateDischargeTest(jtu.JaxTestCase):
shaped_array_ref((4,), jnp.dtype('float32')), shaped_array_ref((4,), jnp.dtype('float32')),
shaped_array_ref((4,), jnp.dtype('float32')) shaped_array_ref((4,), jnp.dtype('float32'))
] ]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
in_avals) in_avals)
discharged_jaxpr, _ = discharge_state( discharged_jaxpr, _ = discharge_state(
stateful_jaxpr, consts, should_discharge=[False, True]) stateful_jaxpr, consts, should_discharge=[False, True])
@ -758,7 +765,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return [] return []
in_avals = [shaped_array_ref((), jnp.float32)] in_avals = [shaped_array_ref((), jnp.float32)]
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals)
def test_partial_discharge(self): def test_partial_discharge(self):
def f(a_ref, b_ref): def f(a_ref, b_ref):
@ -769,7 +776,7 @@ class StateDischargeTest(jtu.JaxTestCase):
scalar_ref_1 = shaped_array_ref((), jnp.float32) scalar_ref_1 = shaped_array_ref((), jnp.float32)
scalar_ref_2 = shaped_array_ref((), jnp.float32) scalar_ref_2 = shaped_array_ref((), jnp.float32)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) wrap_init(f, 2), [scalar_ref_1, scalar_ref_2])
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns) prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns)
@ -980,13 +987,14 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals]) wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)),
[bat_ref_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx)
# vmap-of-discharge # vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals]) wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, *idx_bdims), in_axes=(ref_bdim, *idx_bdims),
@ -1025,13 +1033,14 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[]) out_axes=[])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)),
[bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
# vmap-of-discharge # vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals]) wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, val_bdim, *idx_bdims), in_axes=(ref_bdim, val_bdim, *idx_bdims),
@ -1069,13 +1078,14 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[]) out_axes=[])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)),
[bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
# vmap-of-discharge # vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals]) wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, val_bdim, *idx_bdims), in_axes=(ref_bdim, val_bdim, *idx_bdims),
@ -1414,7 +1424,7 @@ class GeneralRefTest(jtu.JaxTestCase):
ref_addupdate(x_ref, (), x) ref_addupdate(x_ref, (), x)
return [x] return [x]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [AbstractRef(core.AbstractToken())]) wrap_init(f, 1), [AbstractRef(core.AbstractToken())])
self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken) self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)
def test_ref_of_ref(self): def test_ref_of_ref(self):
@ -1423,7 +1433,7 @@ class GeneralRefTest(jtu.JaxTestCase):
return [x_ref] return [x_ref]
# Not sure why you'd ever want to do this, but it works! # Not sure why you'd ever want to do this, but it works!
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), wrap_init(f, 1),
[AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))]) [AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))])
self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef) self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)
self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray) self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray)

View File

@ -17,6 +17,7 @@ import operator
from absl.testing import absltest from absl.testing import absltest
import jax import jax
from jax import api_util
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import util from jax._src import util
@ -62,7 +63,10 @@ class UtilTest(jtu.JaxTestCase):
store.store(aux_output) store.store(aux_output)
return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):]))) return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):])))
wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`. # Wraps `f` as a `WrappedFun`.
wf = lu.wrap_init(
f,
debug_info=api_util.debug_info("test", f, (1, 2), dict(three=3, four=4)))
wf, out_thunk = kw_to_positional(wf, 2) wf, out_thunk = kw_to_positional(wf, 2)
# Call the transformed function. # Call the transformed function.
scaled_positional, scaled_kwargs = wf.call_wrapped(1, 2, three=3, four=4) scaled_positional, scaled_kwargs = wf.call_wrapped(1, 2, three=3, four=4)