[better_errors] Continue adding debug info to Jaxprs (step 7)

This is the next step in a series that started with #26078 and #26313; it adds debug_info to more calls to lu.wrap_init.

This step covers jet, the stateful code, key_reuse, ode, pallas, and tests.
This commit is contained in:
George Necula 2025-02-08 15:19:46 +02:00
parent 8401d9b752
commit 817b3e5757
20 changed files with 172 additions and 87 deletions

View File

@ -610,10 +610,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
except (ValueError, TypeError):
return None
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
def save_wrapped_fun_sourceinfo(wrapper: Callable,
wrapped: Callable | core.DebugInfo | None) -> None:
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
if isinstance(wrapped, core.DebugInfo):
func_src_info = wrapped.func_src_info
elif callable(wrapped):
func_src_info = fun_sourceinfo(wrapped)
else:
return
setattr(wrapper, "__fun_sourceinfo__", func_src_info)
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")

View File

@ -1823,7 +1823,7 @@ def _move_mutable_consts(
invars = (*jaxpr.invars, *mutvars)
effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
effects, None)
effects, closed_jaxpr.jaxpr.debug_info)
return core.ClosedJaxpr(jaxpr, consts), in_mut
@weakref_lru_cache

View File

@ -1062,7 +1062,11 @@ def core_map(
"""
def wrapped(f):
flat_args, in_tree = tree_util.tree_flatten(((), {}))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f,
debug_info=api_util.debug_info("pallas_core_map", f,
(), {})),
in_tree)
with jax_core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args)
out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh,

View File

@ -91,8 +91,11 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate:
"""
flattened_args, treedef = jax.tree.flatten(args)
partial_fun = functools.partial(fun, **kwargs)
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun),
treedef)
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(partial_fun,
debug_info=api_util.debug_info("cost_estimate", fun,
args, kwargs)),
treedef)
avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))

View File

@ -476,7 +476,9 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
compiler_params=compiler_params)
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
new_jaxpr = _to_jaxpr(lu.wrap_init(f), in_avals)
new_jaxpr = _to_jaxpr(lu.wrap_init(f,
debug_info=pjit_jaxpr.jaxpr.debug_info),
in_avals)
out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
elif prim is primitives.run_scoped_p:

View File

@ -23,6 +23,7 @@ import string
from typing import Any, Hashable
import jax
from jax import api_util
from jax import lax
from jax import tree_util
from jax._src import ad_util
@ -845,7 +846,10 @@ def lower_jaxpr_to_func(
def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
wrapped_fun = lu.wrap_init(
f, params,
debug_info=api_util.debug_info("mosaic lower_fun", f,
args, params))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
if consts:
raise NotImplementedError

View File

@ -24,6 +24,7 @@ import math
from typing import Any
import jax
from jax import api_util
from jax import lax
from jax._src import core
from jax._src import linear_util as lu
@ -94,7 +95,12 @@ def _uses_arguments(
index_map: Callable[..., Any], num_args: int
) -> Sequence[bool]:
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args
lu.wrap_init(
index_map,
debug_info=api_util.debug_info("pallas index_map",
index_map,
(0,) * num_args, {})),
(core.ShapedArray((), jnp.int32),) * num_args
)
_, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
return used_inputs

View File

@ -1001,7 +1001,10 @@ def pallas_call_checkify_oob_grid(error: checkify.Error,
)
flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),))
wrapped_loop, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(f), jaxpr_in_tree)
lu.wrap_init(f,
debug_info=api_util.debug_info("checkify oob_grid_access",
f, (0,), {})),
jaxpr_in_tree)
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
avals_in = map(jax_core.get_aval, flat_args)
traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic(
@ -1426,7 +1429,7 @@ def _pallas_call_state_discharge_rule(
)
)
new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_rewritten_body),
lu.wrap_init(_rewritten_body, debug_info=jaxpr.debug_info),
[
*index_map_avals,
*ref_avals,

View File

@ -848,7 +848,11 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
types in addition to :class:`jax.experimental.pallas.MemoryRef`.
"""
flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f,
debug_info=api_util.debug_info("pallas run_scoped",
f, types, kw_types)),
in_tree)
# We allow ref avals to be transformed references.
ref_avals = [t.get_ref_aval() for t in flat_types]
avals = [

View File

@ -415,7 +415,10 @@ def lower_fun(
fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
def f_lowered(ctx: LoweringRuleContext, *args, **params):
wrapped_fun = lu.wrap_init(fn, params)
wrapped_fun = lu.wrap_init(
fn, params,
debug_info=api_util.debug_info("pallas triton lower_fun", fun,
args, params))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
@ -539,7 +542,11 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes):
]
in_tree = tree_util.tree_structure((args, args))
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree
lu.wrap_init(
body,
debug_info=api_util.debug_info("pallas triton associative_scan",
body, (args, args), {})),
in_tree
)
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, in_avals
@ -2185,7 +2192,11 @@ def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
mapped_avals = [jax_core.ShapedArray((), aval.dtype) for aval in ctx.avals_in]
in_tree = tree_util.tree_structure((a, a))
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree
lu.wrap_init(
body,
debug_info=api_util.debug_info("pallas triton reduction",
body, (a, a), {})),
in_tree
)
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, [*mapped_avals, *mapped_avals]

View File

@ -460,12 +460,13 @@ def _addupdate_discharge(x, val, idx, tree):
return _prepend_scatter(x, indexer, val, add=True)
@weakref_lru_cache
def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr):
  """Discharge the state of `closed_jaxpr`, caching the result per jaxpr.

  Args:
    closed_jaxpr: the closed jaxpr to run through `discharge_state`.

  Returns:
    A triple `(discharged_closed_jaxpr, num_outs, fun)`: the discharged jaxpr
    re-closed over its discharged constants, the number of outputs of the
    jaxpr before discharging, and the discharged jaxpr wrapped with
    `lu.wrap_init` (carrying the discharged jaxpr's debug info).
  """
  jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
  # Count the outputs before discharging; the count is returned to the caller.
  num_outs = len(jaxpr.outvars)
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
  discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
  fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr),
                     debug_info=discharged_jaxpr.debug_info)
  return discharged_closed_jaxpr, num_outs, fun
@register_discharge_rule(core.closed_call_p)
@ -598,7 +599,6 @@ def _convert_outputs_to_writes(
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
in_avals = [v.aval for v in jaxpr.invars]
@lu.wrap_init
def eval_jaxpr(*refs):
# We split the refs into the original input refs and the dummy residual
# refs.
@ -610,14 +610,15 @@ def _convert_outputs_to_writes(
res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
else v.aval for v in jaxpr.outvars]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info),
[*in_avals, *res_ref_avals])
assert not consts
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]
def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
assert not jaxpr.constvars, "Jaxpr should not have constvars"
@lu.wrap_init
def eval_jaxpr(*refs):
residual_refs, orig_refs = split_list(refs, [num_res])
residual_vals = [r[...] for r in residual_refs]
@ -629,7 +630,9 @@ def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
aval for aval in res_val_avals]
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info),
[*res_ref_avals, *orig_ref_avals])
return jaxpr
def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
@ -845,12 +848,13 @@ def _run_state_partial_eval_custom(
*[v.aval for v in res_staged_invars], **staged_params)
_, staged_outvars = partition_list(in_unknowns, eqn.outvars)
if num_res:
@lu.wrap_init
def staged(*args):
out = run_state_p.bind(*args, **staged_params)
return out[num_res:]
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
[v.aval for v in res_staged_invars])
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info),
[v.aval for v in res_staged_invars])
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
staged_outvars,
core.closed_call_p,
@ -918,7 +922,9 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool],
ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ())
return []
jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
lu.wrap_init(trans,
debug_info=jaxpr.debug_info),
[v.aval for v in jaxpr.invars])
return jaxpr_trans, consts
def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,

View File

@ -19,6 +19,7 @@ from jax._src.core import (
from jax._src.api_util import (
argnums_partial as argnums_partial,
debug_info as debug_info,
donation_vector as donation_vector,
flatten_axes as flatten_axes,
flatten_fun as flatten_fun,

View File

@ -60,6 +60,7 @@ from functools import partial
import numpy as np
from jax import lax
from jax import api_util
import jax.numpy as jnp
from jax.experimental import pjit
from jax.tree_util import (register_pytree_node, tree_structure,
@ -147,7 +148,9 @@ def jet(fun, primals, series):
store.store(tree)
return ans
f, out_tree = flatten_fun_output(lu.wrap_init(fun))
f, out_tree = flatten_fun_output(
lu.wrap_init(fun,
debug_info=api_util.debug_info("jet", fun, primals, {})))
out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
@ -723,7 +726,7 @@ jet_rules[lax.scatter_add_p] = _scatter_add_rule
def _jet_jaxpr(
jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
) -> tuple[core.ClosedJaxpr, Any]:
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)
f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def)
jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic(
f_jet, primals_and_series_avals)

View File

@ -397,7 +397,10 @@ def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature:
  """Trace `fun` on `args` and return the key-reuse signature of its jaxpr.

  Args:
    fun: the function to analyze.
    *args: example positional arguments (pytrees); only their abstract values,
      obtained with `core.get_aval`, are used for tracing.

  Returns:
    The result of `jaxpr_type_signature` applied to the traced jaxpr.
  """
  args_flat, in_tree = tree_util.tree_flatten(args)
  in_avals_flat = [core.get_aval(arg) for arg in args_flat]
  # Attach debug info so the traced jaxpr records which function it came from.
  wrapped_fun, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun,
                   debug_info=api_util.debug_info("key_reuse", fun, args, {})),
      in_tree)
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
  return jaxpr_type_signature(jaxpr)

View File

@ -28,8 +28,10 @@ Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
from functools import partial
import operator as op
from typing import Callable
import jax
from jax import api_util
import jax.numpy as jnp
from jax._src import core
from jax import custom_derivatives
@ -44,8 +46,9 @@ map = safe_map
zip = safe_zip
def ravel_first_arg(f: Callable, unravel, debug_info: core.DebugInfo):
  """Return `f` wrapped by the `ravel_first_arg_` transformation.

  Args:
    f: the function to wrap.
    unravel: forwarded unchanged to `ravel_first_arg_`.
      NOTE(review): presumably the unflattening function produced by
      `ravel_pytree` -- confirm against the caller.
    debug_info: debug info attached to the `lu.wrap_init`-wrapped `f`.

  Returns:
    The `call_wrapped` bound method of the transformed function.
  """
  return ravel_first_arg_(lu.wrap_init(f, debug_info=debug_info),
                          unravel).call_wrapped
@lu.transformation2
def ravel_first_arg_(f, unravel, y_flat, *args):
@ -179,9 +182,10 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf, hmax=jn
return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
def _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args):
def _odeint_wrapper(func: Callable, rtol, atol, mxstep, hmax, y0, ts, *args):
y0, unravel = ravel_pytree(y0)
func = ravel_first_arg(func, unravel)
debug = api_util.debug_info("odeint", func, args, {})
func = ravel_first_arg(func, unravel, debug)
out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
return jax.vmap(unravel)(out)

View File

@ -17,6 +17,7 @@ import unittest
from absl.testing import absltest
import jax
from jax import api_util
import jax.numpy as jnp
from jax import lax
from jax.experimental import pjit
@ -178,12 +179,13 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
def test_core_call_primitive_inherits_effects(self):
def f(x):
@lu.wrap_init
def f_(x):
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return [x]
return core.call(f_, x)[0]
dbg = api_util.debug_info("test", f_, (2.,), {})
return core.call(
lu.wrap_init(f_, debug_info=dbg), x)[0]
jaxpr = jax.make_jaxpr(f)(2.)
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
self.assertIn(bar_effect, jaxpr.jaxpr.effects)

View File

@ -15,6 +15,7 @@ import functools
from absl.testing import absltest
import jax
from jax import api_util
import jax.numpy as jnp
from jax._src import core
from jax import lax
@ -85,11 +86,12 @@ class NameStackTest(jtu.JaxTestCase):
def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
@jax.named_scope('foo')
def f(x):
@lu.wrap_init
@jax.named_scope('bar')
def _f(x):
return [x + 1]
return core.call(_f, x)[0]
return core.call(lu.wrap_init(
_f,
debug_info=api_util.debug_info("test", _f, (0,), {})), x)[0]
jaxpr = jax.make_jaxpr(f)(2).jaxpr
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')

View File

@ -19,12 +19,13 @@ import functools
import itertools
import math
import sys
from typing import Any
from typing import Any, Callable
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import api_util
from jax import lax
from jax import random
from jax._src import dtypes
@ -61,6 +62,11 @@ jtu.setup_hypothesis(max_examples=50)
intx = dtypes.canonicalize_dtype(jnp.int64)
floatx = dtypes.canonicalize_dtype(jnp.float64)
def wrap_init(f: Callable, nr_args: int):
# wrapper for lu.wrap_init with debugging info
return lu.wrap_init(
f,
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
def is_power_of_two(n: int) -> bool:
  """Return True iff `n` is a positive power of two (1, 2, 4, 8, ...)."""
  # A power of two has exactly one set bit, so clearing the lowest set bit
  # with `n & (n - 1)` leaves zero.
  return (n > 0) and (n & (n - 1) == 0)
@ -2284,7 +2290,7 @@ class PallasPrimitivesTest(PallasBaseTest):
x = pl.load(x_ref, expr())
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
self.assertIn(expected, jaxpr.pretty_print(use_color=False))
@parameterized.parameters(*[
@ -2299,7 +2305,7 @@ class PallasPrimitivesTest(PallasBaseTest):
pl.store(x_ref, expr(), pl.load(x_ref, expr()))
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
self.assertIn(expected, jaxpr.pretty_print(use_color=False))
@parameterized.parameters(*[
@ -2319,7 +2325,7 @@ class PallasPrimitivesTest(PallasBaseTest):
x = pl.swap(x_ref, expr(), pl.load(x_ref, expr()))
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
self.assertIn(expected, jaxpr.pretty_print(use_color=False))

View File

@ -23,6 +23,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import api_util
from jax import random
from jax import lax
from jax._src import core
@ -54,6 +55,12 @@ from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
config.parse_flags_with_absl()
jtu.setup_hypothesis()
def wrap_init(f: Callable, nr_args: int):
# wrapper for lu.wrap_init with debugging info
return lu.wrap_init(
f,
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
class StatePrimitivesTest(jtu.JaxTestCase):
@ -62,7 +69,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def f(x_ref):
return [ref_get(x_ref, ())]
with self.assertRaises(ValueError):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval])
@parameterized.named_parameters(
dict(testcase_name="trivial_get", ref_shape=(1, 2),
@ -121,10 +128,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
return [out]
if should_error:
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval])
else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval])
wrap_init(f, 1), [ref_aval])
self.assertSetEqual(jaxpr.effects,
{ReadEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 1)
@ -139,7 +146,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def f(x_ref, val):
return [ref_swap(x_ref, (), val)]
with self.assertRaises(ValueError):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
@parameterized.named_parameters(
dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
@ -218,10 +225,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
return [out]
if should_error:
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
wrap_init(f, 2), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects,
{WriteEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 1)
@ -295,10 +302,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
return []
if should_error:
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
wrap_init(f, 2), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects,
{AccumEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 0)
@ -309,7 +316,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def f(x_ref, val):
return [ref_addupdate(x_ref, (), val)]
with self.assertRaises(ValueError):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
def test_can_represent_get_and_swap_in_jaxprs(self):
@ -318,7 +325,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x[()] = jnp.int32(2)
return (x[()],)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, swap_p)
@ -331,7 +338,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_addupdate(x, (), jnp.int32(1))
return (x[()],)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
@ -341,14 +348,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x = x_ref[()]
return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
def body(x_ref):
x = x_ref[:, 0]
return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)])
wrap_init(body, 1), [shaped_array_ref((1, 2), jnp.int32)])
self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))
def test_set_custom_pretty_printing_rule(self):
@ -356,14 +363,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x_ref[()] = jnp.int32(2)
return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val):
x_ref[:, 0] = val
return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))
@ -372,14 +379,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x = ref_swap(x_ref, (), jnp.int32(2))
return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val):
x = ref_swap(x_ref, (slice(None), 0), val)
return [x]
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
jaxpr.pretty_print(use_color=False))
@ -389,7 +396,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_addupdate(x_ref, (), jnp.int32(2))
return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
@ -397,7 +404,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_addupdate(x_ref, (slice(None), 0), val)
return []
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))
@ -413,7 +420,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p)
@ -429,7 +436,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p)
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
@ -449,7 +456,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[2].primitive, get_p)
@ -521,12 +528,12 @@ class StatePrimitivesTest(jtu.JaxTestCase):
# discharge-of-vmap
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals])
wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals])
wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, *idx_bdims),
@ -544,7 +551,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, ())
return [a + 1]
in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
@ -560,7 +567,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, (0, 1))
return [a + 1]
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@ -580,7 +587,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return [a + 1]
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), in_avals)
wrap_init(f, 1), in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)
inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
@ -594,7 +601,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
# Discharging should just turn this into a jaxpr that ignores the first
# value and returns second value plus 1.
@ -611,7 +618,7 @@ class StateDischargeTest(jtu.JaxTestCase):
ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@ -631,7 +638,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)
@ -648,7 +655,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return [a + 1]
in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), in_avals)
wrap_init(f, 1), in_avals)
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
@ -665,7 +672,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
in_avals)
# Discharging should just turn this into a jaxpr that adds the first value,
# second value, and 1.
@ -683,7 +690,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
@ -704,7 +711,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones((2, 3), 'float32'))
return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)
@ -718,7 +725,7 @@ class StateDischargeTest(jtu.JaxTestCase):
b = a + 1
return [a, b]
in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
@ -738,7 +745,7 @@ class StateDischargeTest(jtu.JaxTestCase):
shaped_array_ref((4,), jnp.dtype('float32')),
shaped_array_ref((4,), jnp.dtype('float32'))
]
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
in_avals)
discharged_jaxpr, _ = discharge_state(
stateful_jaxpr, consts, should_discharge=[False, True])
@ -758,7 +765,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return []
in_avals = [shaped_array_ref((), jnp.float32)]
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)
pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals)
def test_partial_discharge(self):
def f(a_ref, b_ref):
@ -769,7 +776,7 @@ class StateDischargeTest(jtu.JaxTestCase):
scalar_ref_1 = shaped_array_ref((), jnp.float32)
scalar_ref_2 = shaped_array_ref((), jnp.float32)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [scalar_ref_1, scalar_ref_2])
wrap_init(f, 2), [scalar_ref_1, scalar_ref_2])
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns)
@ -980,13 +987,14 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)),
[bat_ref_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals])
wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, *idx_bdims),
@ -1025,13 +1033,14 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)),
[bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, val_bdim, *idx_bdims),
@ -1069,13 +1078,14 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)),
[bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, val_bdim, *idx_bdims),
@ -1414,7 +1424,7 @@ class GeneralRefTest(jtu.JaxTestCase):
ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [AbstractRef(core.AbstractToken())])
wrap_init(f, 1), [AbstractRef(core.AbstractToken())])
self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)
def test_ref_of_ref(self):
@ -1423,7 +1433,7 @@ class GeneralRefTest(jtu.JaxTestCase):
return [x_ref]
# Not sure why you'd ever want to do this, but it works!
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f),
wrap_init(f, 1),
[AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))])
self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)
self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray)

View File

@ -17,6 +17,7 @@ import operator
from absl.testing import absltest
import jax
from jax import api_util
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import util
@ -62,7 +63,10 @@ class UtilTest(jtu.JaxTestCase):
store.store(aux_output)
return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):])))
wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
# Wraps `f` as a `WrappedFun`.
wf = lu.wrap_init(
f,
debug_info=api_util.debug_info("test", f, (1, 2), dict(three=3, four=4)))
wf, out_thunk = kw_to_positional(wf, 2)
# Call the transformed function.
scaled_positional, scaled_kwargs = wf.call_wrapped(1, 2, three=3, four=4)