From 817b3e57574e6e2c2a08ab626746c445cbbae3a1 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 8 Feb 2025 15:19:46 +0200 Subject: [PATCH] [better_errors] Continue adding debug info to Jaxprs (step 7) This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, pallas, tests. --- jax/_src/api_util.py | 11 ++- jax/_src/interpreters/pxla.py | 2 +- jax/_src/pallas/core.py | 6 +- jax/_src/pallas/cost_estimate.py | 7 +- jax/_src/pallas/mosaic/interpret.py | 4 +- jax/_src/pallas/mosaic/lowering.py | 6 +- jax/_src/pallas/mosaic_gpu/pipeline.py | 8 +- jax/_src/pallas/pallas_call.py | 7 +- jax/_src/pallas/primitives.py | 6 +- jax/_src/pallas/triton/lowering.py | 17 ++++- jax/_src/state/discharge.py | 26 ++++--- jax/api_util.py | 1 + jax/experimental/jet.py | 7 +- jax/experimental/key_reuse/_core.py | 5 +- jax/experimental/ode.py | 12 ++- tests/jaxpr_effects_test.py | 6 +- tests/name_stack_test.py | 6 +- tests/pallas/ops_test.py | 14 +++- tests/state_test.py | 102 ++++++++++++++----------- tests/util_test.py | 6 +- 20 files changed, 172 insertions(+), 87 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 906c40542..86660d580 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -610,10 +610,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None: except (ValueError, TypeError): return None -def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable): +def save_wrapped_fun_sourceinfo(wrapper: Callable, + wrapped: Callable | core.DebugInfo | None) -> None: # Prefer this to functools.wraps because it does not create a reference to # the wrapped function. - setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped)) + if isinstance(wrapped, core.DebugInfo): + func_src_info = wrapped.func_src_info + elif callable(wrapped): + func_src_info = fun_sourceinfo(wrapped) + else: + return + setattr(wrapper, "__fun_sourceinfo__", func_src_info) _fun_name_re = re.compile(r"(?:)") diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7cbeeb080..f8fe90e81 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1823,7 +1823,7 @@ def _move_mutable_consts( invars = (*jaxpr.invars, *mutvars) effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns) jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns, - effects, None) + effects, closed_jaxpr.jaxpr.debug_info) return core.ClosedJaxpr(jaxpr, consts), in_mut @weakref_lru_cache diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index dfda00c0c..c7e3d12fb 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1062,7 +1062,11 @@ def core_map( """ def wrapped(f): flat_args, in_tree = tree_util.tree_flatten(((), {})) - flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) + flat_fun, out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(f, + debug_info=api_util.debug_info("pallas_core_map", f, + (), {})), + in_tree) with jax_core.extend_axis_env_nd(mesh.shape.items()): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index 5b322eedc..73db4a2e2 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -91,8 +91,11 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate: """ flattened_args, 
treedef = jax.tree.flatten(args) partial_fun = functools.partial(fun, **kwargs) - wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun), - treedef) + wrapped_fun, _ = api_util.flatten_fun_nokwargs( + lu.wrap_init(partial_fun, + debug_info=api_util.debug_info("cost_estimate", fun, + args, kwargs)), + treedef) avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index aec9326a4..06e3705e4 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -476,7 +476,9 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params): return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args, compiler_params=compiler_params) in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) - new_jaxpr = _to_jaxpr(lu.wrap_init(f), in_avals) + new_jaxpr = _to_jaxpr(lu.wrap_init(f, + debug_info=pjit_jaxpr.jaxpr.debug_info), + in_avals) out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) elif prim is primitives.run_scoped_p: diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 908b47604..786ece2db 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -23,6 +23,7 @@ import string from typing import Any, Hashable import jax +from jax import api_util from jax import lax from jax import tree_util from jax._src import ad_util @@ -845,7 +846,10 @@ def lower_jaxpr_to_func( def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable: def f_lowered(ctx: LoweringRuleContext, *args, **params): f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) - wrapped_fun = lu.wrap_init(f, params) + wrapped_fun = lu.wrap_init( + f, params, + debug_info=api_util.debug_info("mosaic lower_fun", f, + args, params)) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) if consts: raise NotImplementedError diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 4166ac37c..568f3a169 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -24,6 +24,7 @@ import math from typing import Any import jax +from jax import api_util from jax import lax from jax._src import core from jax._src import linear_util as lu @@ -94,7 +95,12 @@ def _uses_arguments( index_map: Callable[..., Any], num_args: int ) -> Sequence[bool]: jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args + lu.wrap_init( + index_map, + debug_info=api_util.debug_info("pallas index_map", + index_map, + (0,) * num_args, {})), + (core.ShapedArray((), jnp.int32),) * num_args ) _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars)) return used_inputs diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index ccef4152f..f3c5cd704 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1001,7 +1001,10 @@ def pallas_call_checkify_oob_grid(error: checkify.Error, ) flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),)) wrapped_loop, _ = api_util.flatten_fun_nokwargs( - lu.wrap_init(f), jaxpr_in_tree) + lu.wrap_init(f, + debug_info=api_util.debug_info("checkify oob_grid_access", + f, (0,), {})), + jaxpr_in_tree) with 
pallas_core.tracing_grid_env(grid_mapping.grid, ()): avals_in = map(jax_core.get_aval, flat_args) traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic( @@ -1426,7 +1429,7 @@ def _pallas_call_state_discharge_rule( ) ) new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_rewritten_body), + lu.wrap_init(_rewritten_body, debug_info=jaxpr.debug_info), [ *index_map_avals, *ref_avals, diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index ed22fdebe..57837a2be 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -848,7 +848,11 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: types in addition to :class:`jax.experimental.pallas.MemoryRef`. """ flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) - flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) + flat_fun, out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(f, + debug_info=api_util.debug_info("pallas run_scoped", + f, types, kw_types)), + in_tree) # We allow ref avals to be transformed references. ref_avals = [t.get_ref_aval() for t in flat_types] avals = [ diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 87a744cbb..3e80e40d9 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -415,7 +415,10 @@ def lower_fun( fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) def f_lowered(ctx: LoweringRuleContext, *args, **params): - wrapped_fun = lu.wrap_init(fn, params) + wrapped_fun = lu.wrap_init( + fn, params, + debug_info=api_util.debug_info("pallas triton lower_fun", fun, + args, params)) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) @@ -539,7 +542,11 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes): ] in_tree = tree_util.tree_structure((args, args)) flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(body), in_tree + lu.wrap_init( + body, + debug_info=api_util.debug_info("pallas triton associative_scan", + body, (args, args), {})), + in_tree ) combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( flat_fun, in_avals @@ -2185,7 +2192,11 @@ def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes): mapped_avals = [jax_core.ShapedArray((), aval.dtype) for aval in ctx.avals_in] in_tree = tree_util.tree_structure((a, a)) flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(body), in_tree + lu.wrap_init( + body, + debug_info=api_util.debug_info("pallas triton reduction", + body, (a, a), {})), + in_tree ) combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( flat_fun, [*mapped_avals, *mapped_avals] diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 048217807..f75d8690c 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -460,12 +460,13 @@ def _addupdate_discharge(x, val, idx, tree): return _prepend_scatter(x, indexer, val, add=True) @weakref_lru_cache -def _cached_closed_jaxpr_discharge(closed_jaxpr): +def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr): jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts num_outs = len(jaxpr.outvars) discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) - fun = 
lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr)) + fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr), + debug_info=discharged_jaxpr.debug_info) return discharged_closed_jaxpr, num_outs, fun @register_discharge_rule(core.closed_call_p) @@ -598,7 +599,6 @@ def _convert_outputs_to_writes( assert not jaxpr.constvars, "Jaxpr shouldn't have constvars." in_avals = [v.aval for v in jaxpr.invars] - @lu.wrap_init def eval_jaxpr(*refs): # We split the refs into the original input refs and the dummy residual # refs. @@ -610,14 +610,15 @@ def _convert_outputs_to_writes( res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef) else v.aval for v in jaxpr.outvars] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - eval_jaxpr, [*in_avals, *res_ref_avals]) + lu.wrap_init(eval_jaxpr, + debug_info=jaxpr.debug_info), + [*in_avals, *res_ref_avals]) assert not consts return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr: assert not jaxpr.constvars, "Jaxpr should not have constvars" - @lu.wrap_init def eval_jaxpr(*refs): residual_refs, orig_refs = split_list(refs, [num_res]) residual_vals = [r[...] for r in residual_refs] @@ -629,7 +630,9 @@ def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr: res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else aval for aval in res_val_avals] jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( - eval_jaxpr, [*res_ref_avals, *orig_ref_avals]) + lu.wrap_init(eval_jaxpr, + debug_info=jaxpr.debug_info), + [*res_ref_avals, *orig_ref_avals]) return jaxpr def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, @@ -845,12 +848,13 @@ def _run_state_partial_eval_custom( *[v.aval for v in res_staged_invars], **staged_params) _, staged_outvars = partition_list(in_unknowns, eqn.outvars) if num_res: - @lu.wrap_init + def staged(*args): out = run_state_p.bind(*args, **staged_params) return out[num_res:] - staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged, - [v.aval for v in res_staged_invars]) + staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), + [v.aval for v in res_staged_invars]) eqn_staged = pe.new_jaxpr_eqn(res_staged_invars, staged_outvars, core.closed_call_p, @@ -918,7 +922,9 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool], ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ()) return [] jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(trans), [v.aval for v in jaxpr.invars]) + lu.wrap_init(trans, + debug_info=jaxpr.debug_info), + [v.aval for v in jaxpr.invars]) return jaxpr_trans, consts def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, diff --git a/jax/api_util.py b/jax/api_util.py index 50d1c4268..78f984f4b 100644 --- a/jax/api_util.py +++ b/jax/api_util.py @@ -19,6 +19,7 @@ from jax._src.core import ( from jax._src.api_util import ( argnums_partial as argnums_partial, + debug_info as debug_info, donation_vector as donation_vector, flatten_axes as flatten_axes, flatten_fun as flatten_fun, diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 9b3ce9ec8..31515d4e1 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -60,6 +60,7 @@ from functools import partial import numpy as np from jax import lax +from jax import api_util import jax.numpy as jnp from jax.experimental import pjit from jax.tree_util import 
(register_pytree_node, tree_structure, @@ -147,7 +148,9 @@ def jet(fun, primals, series): store.store(tree) return ans - f, out_tree = flatten_fun_output(lu.wrap_init(fun)) + f, out_tree = flatten_fun_output( + lu.wrap_init(fun, + debug_info=api_util.debug_info("jet", fun, primals, {}))) out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series) return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms) @@ -723,7 +726,7 @@ jet_rules[lax.scatter_add_p] = _scatter_add_rule def _jet_jaxpr( jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def ) -> tuple[core.ClosedJaxpr, Any]: - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def) jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic( f_jet, primals_and_series_avals) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 87c41e13c..7275046f5 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -397,7 +397,10 @@ def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature: def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature: args_flat, in_tree = tree_util.tree_flatten(args) in_avals_flat = [core.get_aval(arg) for arg in args_flat] - wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + wrapped_fun, _ = api_util.flatten_fun_nokwargs( + lu.wrap_init(fun, + debug_info=api_util.debug_info("key_reuse", fun, args, {})), + in_tree) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) return jaxpr_type_signature(jaxpr) diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index 987e461a3..db7865124 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -28,8 +28,10 @@ Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf from functools import partial import operator as op +from typing import Callable import jax +from jax import api_util import jax.numpy as jnp from jax._src import core from jax import custom_derivatives @@ -44,8 +46,9 @@ map = safe_map zip = safe_zip -def ravel_first_arg(f, unravel): - return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped +def ravel_first_arg(f: Callable, unravel, debug_info: core.DebugInfo): + return ravel_first_arg_(lu.wrap_init(f, debug_info=debug_info), + unravel).call_wrapped @lu.transformation2 def ravel_first_arg_(f, unravel, y_flat, *args): @@ -179,9 +182,10 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf, hmax=jn return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts) @partial(jax.jit, static_argnums=(0, 1, 2, 3, 4)) -def _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args): +def _odeint_wrapper(func: Callable, rtol, atol, mxstep, hmax, y0, ts, *args): y0, unravel = ravel_pytree(y0) - func = ravel_first_arg(func, unravel) + debug = api_util.debug_info("odeint", func, args, {}) + func = ravel_first_arg(func, unravel, debug) out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args) return jax.vmap(unravel)(out) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 2d63a4834..43c910f00 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -17,6 +17,7 @@ import unittest from absl.testing import absltest import jax +from jax import api_util import jax.numpy as jnp from 
jax import lax from jax.experimental import pjit @@ -178,12 +179,13 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase): def test_core_call_primitive_inherits_effects(self): def f(x): - @lu.wrap_init def f_(x): effect_p.bind(effect=foo_effect) effect_p.bind(effect=bar_effect) return [x] - return core.call(f_, x)[0] + dbg = api_util.debug_info("test", f_, (2.,), {}) + return core.call( + lu.wrap_init(f_, debug_info=dbg), x)[0] jaxpr = jax.make_jaxpr(f)(2.) self.assertIn(foo_effect, jaxpr.jaxpr.effects) self.assertIn(bar_effect, jaxpr.jaxpr.effects) diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index d59c5682e..270707934 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -15,6 +15,7 @@ import functools from absl.testing import absltest import jax +from jax import api_util import jax.numpy as jnp from jax._src import core from jax import lax @@ -85,11 +86,12 @@ class NameStackTest(jtu.JaxTestCase): def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self): @jax.named_scope('foo') def f(x): - @lu.wrap_init @jax.named_scope('bar') def _f(x): return [x + 1] - return core.call(_f, x)[0] + return core.call(lu.wrap_init( + _f, + debug_info=api_util.debug_info("test", _f, (0,), {})), x)[0] jaxpr = jax.make_jaxpr(f)(2).jaxpr self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index a0c891c6f..0495e6194 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -19,12 +19,13 @@ import functools import itertools import math import sys -from typing import Any +from typing import Any, Callable import unittest from absl.testing import absltest from absl.testing import parameterized import jax +from jax import api_util from jax import lax from jax import random from jax._src import dtypes @@ -61,6 +62,11 @@ jtu.setup_hypothesis(max_examples=50) intx = dtypes.canonicalize_dtype(jnp.int64) floatx = dtypes.canonicalize_dtype(jnp.float64) +def wrap_init(f: Callable, nr_args: int): + # wrapper for lu.wrap_init with debugging info + return lu.wrap_init( + f, + debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {})) def is_power_of_two(n: int) -> bool: return (n > 0) and (n & (n - 1) == 0) @@ -2284,7 +2290,7 @@ class PallasPrimitivesTest(PallasBaseTest): x = pl.load(x_ref, expr()) return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) + wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) @parameterized.parameters(*[ @@ -2299,7 +2305,7 @@ class PallasPrimitivesTest(PallasBaseTest): pl.store(x_ref, expr(), pl.load(x_ref, expr())) return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) + wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) @parameterized.parameters(*[ @@ -2319,7 +2325,7 @@ class PallasPrimitivesTest(PallasBaseTest): x = pl.swap(x_ref, expr(), pl.load(x_ref, expr())) return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) + wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) diff --git a/tests/state_test.py b/tests/state_test.py index 0fc37ba47..60a7d8bc9 100644 --- a/tests/state_test.py +++ 
b/tests/state_test.py @@ -23,6 +23,7 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np import jax +from jax import api_util from jax import random from jax import lax from jax._src import core @@ -54,6 +55,12 @@ from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect, config.parse_flags_with_absl() jtu.setup_hypothesis() +def wrap_init(f: Callable, nr_args: int): + # wrapper for lu.wrap_init with debugging info + return lu.wrap_init( + f, + debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {})) + class StatePrimitivesTest(jtu.JaxTestCase): @@ -62,7 +69,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): def f(x_ref): return [ref_get(x_ref, ())] with self.assertRaises(ValueError): - pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval]) + pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval]) @parameterized.named_parameters( dict(testcase_name="trivial_get", ref_shape=(1, 2), @@ -121,10 +128,10 @@ class StatePrimitivesTest(jtu.JaxTestCase): return [out] if should_error: with self.assertRaises(Exception): - pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval]) + pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval]) else: jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [ref_aval]) + wrap_init(f, 1), [ref_aval]) self.assertSetEqual(jaxpr.effects, {ReadEffect(len(jaxpr.constvars))}) self.assertLen(out_avals, 1) @@ -139,7 +146,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): def f(x_ref, val): return [ref_swap(x_ref, (), val)] with self.assertRaises(ValueError): - pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) + pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval]) @parameterized.named_parameters( dict(testcase_name="invalid_val_shape", ref_shape=(1, 2), @@ -218,10 +225,10 @@ class StatePrimitivesTest(jtu.JaxTestCase): return [out] if should_error: with self.assertRaises(Exception): - pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) + pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval]) else: jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [ref_aval, val_aval]) + wrap_init(f, 2), [ref_aval, val_aval]) self.assertSetEqual(jaxpr.effects, {WriteEffect(len(jaxpr.constvars))}) self.assertLen(out_avals, 1) @@ -295,10 +302,10 @@ class StatePrimitivesTest(jtu.JaxTestCase): return [] if should_error: with self.assertRaises(Exception): - pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) + pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval]) else: jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [ref_aval, val_aval]) + wrap_init(f, 2), [ref_aval, val_aval]) self.assertSetEqual(jaxpr.effects, {AccumEffect(len(jaxpr.constvars))}) self.assertLen(out_avals, 0) @@ -309,7 +316,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): def f(x_ref, val): return [ref_addupdate(x_ref, (), val)] with self.assertRaises(ValueError): - pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) + pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval]) def test_can_represent_get_and_swap_in_jaxprs(self): @@ -318,7 +325,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): x[()] = jnp.int32(2) return (x[()],) jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) + wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertLen(consts, 0) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) 
self.assertEqual(jaxpr.eqns[0].primitive, swap_p) @@ -331,7 +338,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): ref_addupdate(x, (), jnp.int32(1)) return (x[()],) jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) + wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertLen(consts, 0) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p) @@ -341,14 +348,14 @@ class StatePrimitivesTest(jtu.JaxTestCase): x = x_ref[()] return [x] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) + wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False)) def body(x_ref): x = x_ref[:, 0] return [x] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)]) + wrap_init(body, 1), [shaped_array_ref((1, 2), jnp.int32)]) self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False)) def test_set_custom_pretty_printing_rule(self): @@ -356,14 +363,14 @@ class StatePrimitivesTest(jtu.JaxTestCase): x_ref[()] = jnp.int32(2) return [] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) + wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x_ref[:, 0] = val return [] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), + wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False)) @@ -372,14 +379,14 @@ class StatePrimitivesTest(jtu.JaxTestCase): x = ref_swap(x_ref, (), jnp.int32(2)) return [x] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) + wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x = ref_swap(x_ref, (slice(None), 0), val) return [x] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), + wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("c:i32[1], a[:,0] <- a[:,0], b", jaxpr.pretty_print(use_color=False)) @@ -389,7 +396,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): ref_addupdate(x_ref, (), jnp.int32(2)) return [] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) + wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False)) @@ -397,7 +404,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): ref_addupdate(x_ref, (slice(None), 0), val) return [] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), + wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False)) @@ -413,7 +420,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): in_avals = [shaped_array_ref((), jnp.dtype('float32')), shaped_array_ref((), jnp.dtype('float32'))] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) 
self.assertEqual(jaxpr.eqns[0].primitive, get_p) self.assertEqual(jaxpr.eqns[1].primitive, get_p) @@ -429,7 +436,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): in_avals = [shaped_array_ref((), jnp.dtype('float32')), shaped_array_ref((), jnp.dtype('float32'))] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) self.assertEqual(jaxpr.eqns[0].primitive, get_p) self.assertEqual(jaxpr.eqns[1].primitive, get_p) self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p) @@ -449,7 +456,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): in_avals = [shaped_array_ref((), jnp.dtype('float32')), shaped_array_ref((), jnp.dtype('float32'))] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p) self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p) self.assertEqual(jaxpr.eqns[2].primitive, get_p) @@ -521,12 +528,12 @@ class StatePrimitivesTest(jtu.JaxTestCase): # discharge-of-vmap f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals]) + wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs) # vmap-of-discharge stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [ref_aval, *idx_avals]) + wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), in_axes=(ref_bdim, *idx_bdims), @@ -544,7 +551,7 @@ class StateDischargeTest(jtu.JaxTestCase): a = ref_get(a_ref, ()) return [a + 1] in_avals = [shaped_array_ref((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) # Discharging should just turn this into a jaxpr that just adds 1. discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) @@ -560,7 +567,7 @@ class StateDischargeTest(jtu.JaxTestCase): a = ref_get(a_ref, (0, 1)) return [a + 1] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) # Discharging should just turn this into a jaxpr that just adds 1. 
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) @@ -580,7 +587,7 @@ class StateDischargeTest(jtu.JaxTestCase): return [a + 1] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), in_avals) + wrap_init(f, 1), in_avals) discharged_jaxpr, discharged_consts = discharge_state( stateful_jaxpr, consts) inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3)) @@ -594,7 +601,7 @@ class StateDischargeTest(jtu.JaxTestCase): return [] in_avals = [shaped_array_ref((), jnp.dtype('float32')), core.ShapedArray((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) # Discharging should just turn this into a jaxpr that ignores the first # value and returns second value plus 1. @@ -611,7 +618,7 @@ class StateDischargeTest(jtu.JaxTestCase): ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32'))) return [] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) # Discharging should just turn this into a jaxpr that just adds 1. discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) @@ -631,7 +638,7 @@ class StateDischargeTest(jtu.JaxTestCase): a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32') return [] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, discharged_consts = discharge_state( stateful_jaxpr, consts) @@ -648,7 +655,7 @@ class StateDischargeTest(jtu.JaxTestCase): return [a + 1] in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)] stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), in_avals) + wrap_init(f, 1), in_avals) discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) self.assertLen(discharged_jaxpr.invars, 1) @@ -665,7 +672,7 @@ class StateDischargeTest(jtu.JaxTestCase): return [] in_avals = [shaped_array_ref((), jnp.dtype('float32')), core.ShapedArray((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals) # Discharging should just turn this into a jaxpr that adds the first value, # second value, and 1. 
@@ -683,7 +690,7 @@ class StateDischargeTest(jtu.JaxTestCase): jnp.ones(2, dtype=jnp.dtype('float32'))) return [] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) self.assertLen(discharged_jaxpr.invars, 1) @@ -704,7 +711,7 @@ class StateDischargeTest(jtu.JaxTestCase): jnp.ones((2, 3), 'float32')) return [] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, discharged_consts = discharge_state( stateful_jaxpr, consts) @@ -718,7 +725,7 @@ class StateDischargeTest(jtu.JaxTestCase): b = a + 1 return [a, b] in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) self.assertLen(discharged_jaxpr.invars, 1) @@ -738,7 +745,7 @@ class StateDischargeTest(jtu.JaxTestCase): shaped_array_ref((4,), jnp.dtype('float32')), shaped_array_ref((4,), jnp.dtype('float32')) ] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals) discharged_jaxpr, _ = discharge_state( stateful_jaxpr, consts, should_discharge=[False, True]) @@ -758,7 +765,7 @@ class StateDischargeTest(jtu.JaxTestCase): return [] in_avals = [shaped_array_ref((), jnp.float32)] - pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) + pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) def test_partial_discharge(self): def f(a_ref, b_ref): @@ -769,7 +776,7 @@ class StateDischargeTest(jtu.JaxTestCase): scalar_ref_1 = shaped_array_ref((), jnp.float32) scalar_ref_2 = shaped_array_ref((), jnp.float32) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) + wrap_init(f, 2), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns) @@ -980,13 +987,14 @@ if CAN_USE_HYPOTHESIS: f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals]) + wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, *bat_non_slice_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) # vmap-of-discharge stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [ref_aval, *idx_avals]) + wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), in_axes=(ref_bdim, *idx_bdims), @@ -1025,13 +1033,14 @@ if CAN_USE_HYPOTHESIS: f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), out_axes=[]) stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, 
*bat_non_slice_idx_avals]) + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) # vmap-of-discharge stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [ref_aval, val_aval, *idx_avals]) + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), in_axes=(ref_bdim, val_bdim, *idx_bdims), @@ -1069,13 +1078,14 @@ if CAN_USE_HYPOTHESIS: f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), out_axes=[]) stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) # vmap-of-discharge stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [ref_aval, val_aval, *idx_avals]) + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), in_axes=(ref_bdim, val_bdim, *idx_bdims), @@ -1414,7 +1424,7 @@ class GeneralRefTest(jtu.JaxTestCase): ref_addupdate(x_ref, (), x) return [x] jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [AbstractRef(core.AbstractToken())]) + wrap_init(f, 1), [AbstractRef(core.AbstractToken())]) self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken) def test_ref_of_ref(self): @@ -1423,7 +1433,7 @@ class GeneralRefTest(jtu.JaxTestCase): return [x_ref] # Not sure why you'd ever want to do this, but it works! jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), + wrap_init(f, 1), [AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))]) self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef) self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray) diff --git a/tests/util_test.py b/tests/util_test.py index 5e99fff4b..cb803d66b 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -17,6 +17,7 @@ import operator from absl.testing import absltest import jax +from jax import api_util from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util @@ -62,7 +63,10 @@ class UtilTest(jtu.JaxTestCase): store.store(aux_output) return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):]))) - wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`. + # Wraps `f` as a `WrappedFun`. + wf = lu.wrap_init( + f, + debug_info=api_util.debug_info("test", f, (1, 2), dict(three=3, four=4))) wf, out_thunk = kw_to_positional(wf, 2) # Call the transformed function. scaled_positional, scaled_kwargs = wf.call_wrapped(1, 2, three=3, four=4)
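
For reference, the pattern this series applies at each internal lu.wrap_init call site can be sketched in isolation as below. This is an illustrative sketch, not part of the patch: my_fn, the "illustrative_example" tag, and the example avals are hypothetical stand-ins, not names taken from the diff.

    # Sketch of the debug_info-threading pattern (hypothetical example).
    import jax.numpy as jnp
    from jax._src import api_util
    from jax._src import core as jax_core
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe

    def my_fn(x):
      return jnp.sin(x) * 2.0

    # Before: lu.wrap_init(my_fn) attaches no provenance, so the resulting
    # Jaxpr cannot report which user function it came from.
    # After: build a DebugInfo describing what was traced and why, and pass
    # it to lu.wrap_init via the debug_info keyword.
    dbg = api_util.debug_info("illustrative_example", my_fn, (0.0,), {})
    wrapped = lu.wrap_init(my_fn, debug_info=dbg)

    # Tracing now produces a Jaxpr that carries the debug_info.
    jaxpr = pe.trace_to_jaxpr_dynamic(
        wrapped, [jax_core.ShapedArray((), jnp.float32)])[0]
    print(jaxpr.debug_info)  # names my_fn rather than an anonymous wrapper

The point of threading debug_info through these wrappers, per the better_errors series, is that errors raised while processing the Jaxpr can point at the user's function (its name and source location) instead of an internal lambda or wrapper.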