mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
This commit is contained in:
parent
8401d9b752
commit
817b3e5757
@ -610,10 +610,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
|
||||
def save_wrapped_fun_sourceinfo(wrapper: Callable,
|
||||
wrapped: Callable | core.DebugInfo | None) -> None:
|
||||
# Prefer this to functools.wraps because it does not create a reference to
|
||||
# the wrapped function.
|
||||
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
|
||||
if isinstance(wrapped, core.DebugInfo):
|
||||
func_src_info = wrapped.func_src_info
|
||||
elif callable(wrapped):
|
||||
func_src_info = fun_sourceinfo(wrapped)
|
||||
else:
|
||||
return
|
||||
setattr(wrapper, "__fun_sourceinfo__", func_src_info)
|
||||
|
||||
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
|
||||
|
||||
|
@ -1823,7 +1823,7 @@ def _move_mutable_consts(
|
||||
invars = (*jaxpr.invars, *mutvars)
|
||||
effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
|
||||
jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
|
||||
effects, None)
|
||||
effects, closed_jaxpr.jaxpr.debug_info)
|
||||
return core.ClosedJaxpr(jaxpr, consts), in_mut
|
||||
|
||||
@weakref_lru_cache
|
||||
|
@ -1062,7 +1062,11 @@ def core_map(
|
||||
"""
|
||||
def wrapped(f):
|
||||
flat_args, in_tree = tree_util.tree_flatten(((), {}))
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f,
|
||||
debug_info=api_util.debug_info("pallas_core_map", f,
|
||||
(), {})),
|
||||
in_tree)
|
||||
with jax_core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args)
|
||||
out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh,
|
||||
|
@ -91,8 +91,11 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate:
|
||||
"""
|
||||
flattened_args, treedef = jax.tree.flatten(args)
|
||||
partial_fun = functools.partial(fun, **kwargs)
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun),
|
||||
treedef)
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(partial_fun,
|
||||
debug_info=api_util.debug_info("cost_estimate", fun,
|
||||
args, kwargs)),
|
||||
treedef)
|
||||
avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
|
||||
estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))
|
||||
|
@ -476,7 +476,9 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
|
||||
compiler_params=compiler_params)
|
||||
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
|
||||
new_jaxpr = _to_jaxpr(lu.wrap_init(f), in_avals)
|
||||
new_jaxpr = _to_jaxpr(lu.wrap_init(f,
|
||||
debug_info=pjit_jaxpr.jaxpr.debug_info),
|
||||
in_avals)
|
||||
out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
|
||||
|
||||
elif prim is primitives.run_scoped_p:
|
||||
|
@ -23,6 +23,7 @@ import string
|
||||
from typing import Any, Hashable
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import ad_util
|
||||
@ -845,7 +846,10 @@ def lower_jaxpr_to_func(
|
||||
def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
|
||||
def f_lowered(ctx: LoweringRuleContext, *args, **params):
|
||||
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
||||
wrapped_fun = lu.wrap_init(f, params)
|
||||
wrapped_fun = lu.wrap_init(
|
||||
f, params,
|
||||
debug_info=api_util.debug_info("mosaic lower_fun", f,
|
||||
args, params))
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
if consts:
|
||||
raise NotImplementedError
|
||||
|
@ -24,6 +24,7 @@ import math
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
@ -94,7 +95,12 @@ def _uses_arguments(
|
||||
index_map: Callable[..., Any], num_args: int
|
||||
) -> Sequence[bool]:
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args
|
||||
lu.wrap_init(
|
||||
index_map,
|
||||
debug_info=api_util.debug_info("pallas index_map",
|
||||
index_map,
|
||||
(0,) * num_args, {})),
|
||||
(core.ShapedArray((), jnp.int32),) * num_args
|
||||
)
|
||||
_, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
|
||||
return used_inputs
|
||||
|
@ -1001,7 +1001,10 @@ def pallas_call_checkify_oob_grid(error: checkify.Error,
|
||||
)
|
||||
flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),))
|
||||
wrapped_loop, _ = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(f), jaxpr_in_tree)
|
||||
lu.wrap_init(f,
|
||||
debug_info=api_util.debug_info("checkify oob_grid_access",
|
||||
f, (0,), {})),
|
||||
jaxpr_in_tree)
|
||||
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
|
||||
avals_in = map(jax_core.get_aval, flat_args)
|
||||
traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
@ -1426,7 +1429,7 @@ def _pallas_call_state_discharge_rule(
|
||||
)
|
||||
)
|
||||
new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_rewritten_body),
|
||||
lu.wrap_init(_rewritten_body, debug_info=jaxpr.debug_info),
|
||||
[
|
||||
*index_map_avals,
|
||||
*ref_avals,
|
||||
|
@ -848,7 +848,11 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
|
||||
types in addition to :class:`jax.experimental.pallas.MemoryRef`.
|
||||
"""
|
||||
flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f,
|
||||
debug_info=api_util.debug_info("pallas run_scoped",
|
||||
f, types, kw_types)),
|
||||
in_tree)
|
||||
# We allow ref avals to be transformed references.
|
||||
ref_avals = [t.get_ref_aval() for t in flat_types]
|
||||
avals = [
|
||||
|
@ -415,7 +415,10 @@ def lower_fun(
|
||||
fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
||||
|
||||
def f_lowered(ctx: LoweringRuleContext, *args, **params):
|
||||
wrapped_fun = lu.wrap_init(fn, params)
|
||||
wrapped_fun = lu.wrap_init(
|
||||
fn, params,
|
||||
debug_info=api_util.debug_info("pallas triton lower_fun", fun,
|
||||
args, params))
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
|
||||
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
|
||||
@ -539,7 +542,11 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes):
|
||||
]
|
||||
in_tree = tree_util.tree_structure((args, args))
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(body), in_tree
|
||||
lu.wrap_init(
|
||||
body,
|
||||
debug_info=api_util.debug_info("pallas triton associative_scan",
|
||||
body, (args, args), {})),
|
||||
in_tree
|
||||
)
|
||||
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
flat_fun, in_avals
|
||||
@ -2185,7 +2192,11 @@ def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
|
||||
mapped_avals = [jax_core.ShapedArray((), aval.dtype) for aval in ctx.avals_in]
|
||||
in_tree = tree_util.tree_structure((a, a))
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(body), in_tree
|
||||
lu.wrap_init(
|
||||
body,
|
||||
debug_info=api_util.debug_info("pallas triton reduction",
|
||||
body, (a, a), {})),
|
||||
in_tree
|
||||
)
|
||||
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
flat_fun, [*mapped_avals, *mapped_avals]
|
||||
|
@ -460,12 +460,13 @@ def _addupdate_discharge(x, val, idx, tree):
|
||||
return _prepend_scatter(x, indexer, val, add=True)
|
||||
|
||||
@weakref_lru_cache
|
||||
def _cached_closed_jaxpr_discharge(closed_jaxpr):
|
||||
def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr):
|
||||
jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
|
||||
num_outs = len(jaxpr.outvars)
|
||||
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
|
||||
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
|
||||
fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr))
|
||||
fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr),
|
||||
debug_info=discharged_jaxpr.debug_info)
|
||||
return discharged_closed_jaxpr, num_outs, fun
|
||||
|
||||
@register_discharge_rule(core.closed_call_p)
|
||||
@ -598,7 +599,6 @@ def _convert_outputs_to_writes(
|
||||
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
|
||||
|
||||
in_avals = [v.aval for v in jaxpr.invars]
|
||||
@lu.wrap_init
|
||||
def eval_jaxpr(*refs):
|
||||
# We split the refs into the original input refs and the dummy residual
|
||||
# refs.
|
||||
@ -610,14 +610,15 @@ def _convert_outputs_to_writes(
|
||||
res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
|
||||
else v.aval for v in jaxpr.outvars]
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [*in_avals, *res_ref_avals])
|
||||
lu.wrap_init(eval_jaxpr,
|
||||
debug_info=jaxpr.debug_info),
|
||||
[*in_avals, *res_ref_avals])
|
||||
assert not consts
|
||||
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]
|
||||
|
||||
def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
assert not jaxpr.constvars, "Jaxpr should not have constvars"
|
||||
|
||||
@lu.wrap_init
|
||||
def eval_jaxpr(*refs):
|
||||
residual_refs, orig_refs = split_list(refs, [num_res])
|
||||
residual_vals = [r[...] for r in residual_refs]
|
||||
@ -629,7 +630,9 @@ def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
|
||||
aval for aval in res_val_avals]
|
||||
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
|
||||
lu.wrap_init(eval_jaxpr,
|
||||
debug_info=jaxpr.debug_info),
|
||||
[*res_ref_avals, *orig_ref_avals])
|
||||
return jaxpr
|
||||
|
||||
def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
|
||||
@ -845,12 +848,13 @@ def _run_state_partial_eval_custom(
|
||||
*[v.aval for v in res_staged_invars], **staged_params)
|
||||
_, staged_outvars = partition_list(in_unknowns, eqn.outvars)
|
||||
if num_res:
|
||||
@lu.wrap_init
|
||||
|
||||
def staged(*args):
|
||||
out = run_state_p.bind(*args, **staged_params)
|
||||
return out[num_res:]
|
||||
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
|
||||
[v.aval for v in res_staged_invars])
|
||||
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info),
|
||||
[v.aval for v in res_staged_invars])
|
||||
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
|
||||
staged_outvars,
|
||||
core.closed_call_p,
|
||||
@ -918,7 +922,9 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool],
|
||||
ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ())
|
||||
return []
|
||||
jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
|
||||
lu.wrap_init(trans,
|
||||
debug_info=jaxpr.debug_info),
|
||||
[v.aval for v in jaxpr.invars])
|
||||
return jaxpr_trans, consts
|
||||
|
||||
def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,
|
||||
|
@ -19,6 +19,7 @@ from jax._src.core import (
|
||||
|
||||
from jax._src.api_util import (
|
||||
argnums_partial as argnums_partial,
|
||||
debug_info as debug_info,
|
||||
donation_vector as donation_vector,
|
||||
flatten_axes as flatten_axes,
|
||||
flatten_fun as flatten_fun,
|
||||
|
@ -60,6 +60,7 @@ from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax import api_util
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
@ -147,7 +148,9 @@ def jet(fun, primals, series):
|
||||
store.store(tree)
|
||||
return ans
|
||||
|
||||
f, out_tree = flatten_fun_output(lu.wrap_init(fun))
|
||||
f, out_tree = flatten_fun_output(
|
||||
lu.wrap_init(fun,
|
||||
debug_info=api_util.debug_info("jet", fun, primals, {})))
|
||||
out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
|
||||
return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
|
||||
|
||||
@ -723,7 +726,7 @@ jet_rules[lax.scatter_add_p] = _scatter_add_rule
|
||||
def _jet_jaxpr(
|
||||
jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
|
||||
) -> tuple[core.ClosedJaxpr, Any]:
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)
|
||||
f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def)
|
||||
jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
f_jet, primals_and_series_avals)
|
||||
|
@ -397,7 +397,10 @@ def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
|
||||
def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature:
|
||||
args_flat, in_tree = tree_util.tree_flatten(args)
|
||||
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(fun,
|
||||
debug_info=api_util.debug_info("key_reuse", fun, args, {})),
|
||||
in_tree)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
return jaxpr_type_signature(jaxpr)
|
||||
|
||||
|
@ -28,8 +28,10 @@ Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
|
||||
|
||||
from functools import partial
|
||||
import operator as op
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax import custom_derivatives
|
||||
@ -44,8 +46,9 @@ map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
|
||||
def ravel_first_arg(f, unravel):
|
||||
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
|
||||
def ravel_first_arg(f: Callable, unravel, debug_info: core.DebugInfo):
|
||||
return ravel_first_arg_(lu.wrap_init(f, debug_info=debug_info),
|
||||
unravel).call_wrapped
|
||||
|
||||
@lu.transformation2
|
||||
def ravel_first_arg_(f, unravel, y_flat, *args):
|
||||
@ -179,9 +182,10 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf, hmax=jn
|
||||
return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
|
||||
def _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args):
|
||||
def _odeint_wrapper(func: Callable, rtol, atol, mxstep, hmax, y0, ts, *args):
|
||||
y0, unravel = ravel_pytree(y0)
|
||||
func = ravel_first_arg(func, unravel)
|
||||
debug = api_util.debug_info("odeint", func, args, {})
|
||||
func = ravel_first_arg(func, unravel, debug)
|
||||
out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
|
||||
return jax.vmap(unravel)(out)
|
||||
|
||||
|
@ -17,6 +17,7 @@ import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import api_util
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax.experimental import pjit
|
||||
@ -178,12 +179,13 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
def test_core_call_primitive_inherits_effects(self):
|
||||
|
||||
def f(x):
|
||||
@lu.wrap_init
|
||||
def f_(x):
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return [x]
|
||||
return core.call(f_, x)[0]
|
||||
dbg = api_util.debug_info("test", f_, (2.,), {})
|
||||
return core.call(
|
||||
lu.wrap_init(f_, debug_info=dbg), x)[0]
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
|
||||
self.assertIn(bar_effect, jaxpr.jaxpr.effects)
|
||||
|
@ -15,6 +15,7 @@ import functools
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import api_util
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
@ -85,11 +86,12 @@ class NameStackTest(jtu.JaxTestCase):
|
||||
def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
|
||||
@jax.named_scope('foo')
|
||||
def f(x):
|
||||
@lu.wrap_init
|
||||
@jax.named_scope('bar')
|
||||
def _f(x):
|
||||
return [x + 1]
|
||||
return core.call(_f, x)[0]
|
||||
return core.call(lu.wrap_init(
|
||||
_f,
|
||||
debug_info=api_util.debug_info("test", _f, (0,), {})), x)[0]
|
||||
|
||||
jaxpr = jax.make_jaxpr(f)(2).jaxpr
|
||||
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
|
||||
|
@ -19,12 +19,13 @@ import functools
|
||||
import itertools
|
||||
import math
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax._src import dtypes
|
||||
@ -61,6 +62,11 @@ jtu.setup_hypothesis(max_examples=50)
|
||||
intx = dtypes.canonicalize_dtype(jnp.int64)
|
||||
floatx = dtypes.canonicalize_dtype(jnp.float64)
|
||||
|
||||
def wrap_init(f: Callable, nr_args: int):
|
||||
# wrapper for lu.wrap_init with debugging info
|
||||
return lu.wrap_init(
|
||||
f,
|
||||
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
|
||||
|
||||
def is_power_of_two(n: int) -> bool:
|
||||
return (n > 0) and (n & (n - 1) == 0)
|
||||
@ -2284,7 +2290,7 @@ class PallasPrimitivesTest(PallasBaseTest):
|
||||
x = pl.load(x_ref, expr())
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
|
||||
wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
|
||||
self.assertIn(expected, jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@parameterized.parameters(*[
|
||||
@ -2299,7 +2305,7 @@ class PallasPrimitivesTest(PallasBaseTest):
|
||||
pl.store(x_ref, expr(), pl.load(x_ref, expr()))
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
|
||||
wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
|
||||
self.assertIn(expected, jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@parameterized.parameters(*[
|
||||
@ -2319,7 +2325,7 @@ class PallasPrimitivesTest(PallasBaseTest):
|
||||
x = pl.swap(x_ref, expr(), pl.load(x_ref, expr()))
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
|
||||
wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
|
||||
self.assertIn(expected, jaxpr.pretty_print(use_color=False))
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import random
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
@ -54,6 +55,12 @@ from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
|
||||
config.parse_flags_with_absl()
|
||||
jtu.setup_hypothesis()
|
||||
|
||||
def wrap_init(f: Callable, nr_args: int):
|
||||
# wrapper for lu.wrap_init with debugging info
|
||||
return lu.wrap_init(
|
||||
f,
|
||||
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
|
||||
|
||||
|
||||
class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
|
||||
@ -62,7 +69,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def f(x_ref):
|
||||
return [ref_get(x_ref, ())]
|
||||
with self.assertRaises(ValueError):
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
|
||||
pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name="trivial_get", ref_shape=(1, 2),
|
||||
@ -121,10 +128,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
return [out]
|
||||
if should_error:
|
||||
with self.assertRaises(Exception):
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
|
||||
pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval])
|
||||
else:
|
||||
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval])
|
||||
wrap_init(f, 1), [ref_aval])
|
||||
self.assertSetEqual(jaxpr.effects,
|
||||
{ReadEffect(len(jaxpr.constvars))})
|
||||
self.assertLen(out_avals, 1)
|
||||
@ -139,7 +146,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def f(x_ref, val):
|
||||
return [ref_swap(x_ref, (), val)]
|
||||
with self.assertRaises(ValueError):
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
|
||||
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
|
||||
@ -218,10 +225,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
return [out]
|
||||
if should_error:
|
||||
with self.assertRaises(Exception):
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
|
||||
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
|
||||
else:
|
||||
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, val_aval])
|
||||
wrap_init(f, 2), [ref_aval, val_aval])
|
||||
self.assertSetEqual(jaxpr.effects,
|
||||
{WriteEffect(len(jaxpr.constvars))})
|
||||
self.assertLen(out_avals, 1)
|
||||
@ -295,10 +302,10 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
return []
|
||||
if should_error:
|
||||
with self.assertRaises(Exception):
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
|
||||
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
|
||||
else:
|
||||
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, val_aval])
|
||||
wrap_init(f, 2), [ref_aval, val_aval])
|
||||
self.assertSetEqual(jaxpr.effects,
|
||||
{AccumEffect(len(jaxpr.constvars))})
|
||||
self.assertLen(out_avals, 0)
|
||||
@ -309,7 +316,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def f(x_ref, val):
|
||||
return [ref_addupdate(x_ref, (), val)]
|
||||
with self.assertRaises(ValueError):
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
|
||||
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
|
||||
|
||||
def test_can_represent_get_and_swap_in_jaxprs(self):
|
||||
|
||||
@ -318,7 +325,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x[()] = jnp.int32(2)
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
|
||||
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, swap_p)
|
||||
@ -331,7 +338,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
ref_addupdate(x, (), jnp.int32(1))
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
|
||||
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
|
||||
@ -341,14 +348,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x = x_ref[()]
|
||||
return [x]
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
|
||||
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
|
||||
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def body(x_ref):
|
||||
x = x_ref[:, 0]
|
||||
return [x]
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)])
|
||||
wrap_init(body, 1), [shaped_array_ref((1, 2), jnp.int32)])
|
||||
self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_set_custom_pretty_printing_rule(self):
|
||||
@ -356,14 +363,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x_ref[()] = jnp.int32(2)
|
||||
return []
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
|
||||
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
|
||||
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def body(x_ref, val):
|
||||
x_ref[:, 0] = val
|
||||
return []
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
|
||||
wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
|
||||
core.ShapedArray((1,), jnp.int32)])
|
||||
self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@ -372,14 +379,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x = ref_swap(x_ref, (), jnp.int32(2))
|
||||
return [x]
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
|
||||
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
|
||||
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def body(x_ref, val):
|
||||
x = ref_swap(x_ref, (slice(None), 0), val)
|
||||
return [x]
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
|
||||
wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
|
||||
core.ShapedArray((1,), jnp.int32)])
|
||||
self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
|
||||
jaxpr.pretty_print(use_color=False))
|
||||
@ -389,7 +396,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
ref_addupdate(x_ref, (), jnp.int32(2))
|
||||
return []
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
|
||||
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
|
||||
|
||||
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@ -397,7 +404,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
ref_addupdate(x_ref, (slice(None), 0), val)
|
||||
return []
|
||||
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
|
||||
wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32),
|
||||
core.ShapedArray((1,), jnp.int32)])
|
||||
self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@ -413,7 +420,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
|
||||
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
||||
shaped_array_ref((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, get_p)
|
||||
|
||||
@ -429,7 +436,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
|
||||
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
||||
shaped_array_ref((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, get_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
|
||||
@ -449,7 +456,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
|
||||
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
||||
shaped_array_ref((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, get_p)
|
||||
@ -521,12 +528,12 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
# discharge-of-vmap
|
||||
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals])
|
||||
wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals])
|
||||
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
|
||||
# vmap-of-discharge
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, *idx_avals])
|
||||
wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals])
|
||||
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
|
||||
in_axes=(ref_bdim, *idx_bdims),
|
||||
@ -544,7 +551,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
a = ref_get(a_ref, ())
|
||||
return [a + 1]
|
||||
in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||
@ -560,7 +567,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
a = ref_get(a_ref, (0, 1))
|
||||
return [a + 1]
|
||||
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
|
||||
@ -580,7 +587,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
return [a + 1]
|
||||
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), in_avals)
|
||||
wrap_init(f, 1), in_avals)
|
||||
discharged_jaxpr, discharged_consts = discharge_state(
|
||||
stateful_jaxpr, consts)
|
||||
inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
|
||||
@ -594,7 +601,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
return []
|
||||
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that ignores the first
|
||||
# value and returns second value plus 1.
|
||||
@ -611,7 +618,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
|
||||
@ -631,7 +638,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
|
||||
return []
|
||||
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
discharged_jaxpr, discharged_consts = discharge_state(
|
||||
stateful_jaxpr, consts)
|
||||
@ -648,7 +655,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
return [a + 1]
|
||||
in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), in_avals)
|
||||
wrap_init(f, 1), in_avals)
|
||||
|
||||
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
@ -665,7 +672,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
return []
|
||||
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that adds the first value,
|
||||
# second value, and 1.
|
||||
@ -683,7 +690,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
@ -704,7 +711,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
jnp.ones((2, 3), 'float32'))
|
||||
return []
|
||||
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
discharged_jaxpr, discharged_consts = discharge_state(
|
||||
stateful_jaxpr, consts)
|
||||
@ -718,7 +725,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
b = a + 1
|
||||
return [a, b]
|
||||
in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
@ -738,7 +745,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
shaped_array_ref((4,), jnp.dtype('float32')),
|
||||
shaped_array_ref((4,), jnp.dtype('float32'))
|
||||
]
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = discharge_state(
|
||||
stateful_jaxpr, consts, should_discharge=[False, True])
|
||||
@ -758,7 +765,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
return []
|
||||
|
||||
in_avals = [shaped_array_ref((), jnp.float32)]
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)
|
||||
pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals)
|
||||
|
||||
def test_partial_discharge(self):
|
||||
def f(a_ref, b_ref):
|
||||
@ -769,7 +776,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
scalar_ref_1 = shaped_array_ref((), jnp.float32)
|
||||
scalar_ref_2 = shaped_array_ref((), jnp.float32)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [scalar_ref_1, scalar_ref_2])
|
||||
wrap_init(f, 2), [scalar_ref_1, scalar_ref_2])
|
||||
|
||||
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
|
||||
prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns)
|
||||
@ -980,13 +987,14 @@ if CAN_USE_HYPOTHESIS:
|
||||
|
||||
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
|
||||
wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)),
|
||||
[bat_ref_aval, *bat_non_slice_idx_avals])
|
||||
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx)
|
||||
|
||||
# vmap-of-discharge
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, *idx_avals])
|
||||
wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals])
|
||||
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
|
||||
in_axes=(ref_bdim, *idx_bdims),
|
||||
@ -1025,13 +1033,14 @@ if CAN_USE_HYPOTHESIS:
|
||||
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
|
||||
out_axes=[])
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
|
||||
wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)),
|
||||
[bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
|
||||
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
|
||||
|
||||
# vmap-of-discharge
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
|
||||
wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals])
|
||||
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
|
||||
in_axes=(ref_bdim, val_bdim, *idx_bdims),
|
||||
@ -1069,13 +1078,14 @@ if CAN_USE_HYPOTHESIS:
|
||||
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
|
||||
out_axes=[])
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
|
||||
wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)),
|
||||
[bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
|
||||
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
|
||||
|
||||
# vmap-of-discharge
|
||||
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
|
||||
wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals])
|
||||
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
|
||||
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
|
||||
in_axes=(ref_bdim, val_bdim, *idx_bdims),
|
||||
@ -1414,7 +1424,7 @@ class GeneralRefTest(jtu.JaxTestCase):
|
||||
ref_addupdate(x_ref, (), x)
|
||||
return [x]
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [AbstractRef(core.AbstractToken())])
|
||||
wrap_init(f, 1), [AbstractRef(core.AbstractToken())])
|
||||
self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)
|
||||
|
||||
def test_ref_of_ref(self):
|
||||
@ -1423,7 +1433,7 @@ class GeneralRefTest(jtu.JaxTestCase):
|
||||
return [x_ref]
|
||||
# Not sure why you'd ever want to do this, but it works!
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f),
|
||||
wrap_init(f, 1),
|
||||
[AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))])
|
||||
self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)
|
||||
self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray)
|
||||
|
@ -17,6 +17,7 @@ import operator
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
@ -62,7 +63,10 @@ class UtilTest(jtu.JaxTestCase):
|
||||
store.store(aux_output)
|
||||
return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):])))
|
||||
|
||||
wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
|
||||
# Wraps `f` as a `WrappedFun`.
|
||||
wf = lu.wrap_init(
|
||||
f,
|
||||
debug_info=api_util.debug_info("test", f, (1, 2), dict(three=3, four=4)))
|
||||
wf, out_thunk = kw_to_positional(wf, 2)
|
||||
# Call the transformed function.
|
||||
scaled_positional, scaled_kwargs = wf.call_wrapped(1, 2, three=3, four=4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user