From 5101184ad46bae25167fd0b7cddb0aa40146ddc9 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 17 Feb 2023 12:45:39 -0800 Subject: [PATCH] Add initial implementation of a `run_state` primitive --- jax/_src/interpreters/partial_eval.py | 97 +++- jax/_src/lax/control_flow/conditionals.py | 17 +- jax/_src/lax/control_flow/for_loop.py | 39 +- jax/_src/state/__init__.py | 4 - jax/_src/state/discharge.py | 474 ++++++++++++++++- jax/_src/state/primitives.py | 9 +- jax/_src/state/utils.py | 54 ++ tests/state_test.py | 588 +++++++++++++++++----- 8 files changed, 1100 insertions(+), 182 deletions(-) create mode 100644 jax/_src/state/utils.py diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 1fbba4856..587bef00b 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -42,6 +42,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent) +from jax._src.state.types import AbstractRef from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, KeyPath, generate_key_paths, keystr) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, @@ -1123,9 +1124,36 @@ def partial_eval_jaxpr_custom( ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars) if type(ensure_out_inst) is bool: ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars) - return _partial_eval_jaxpr_custom_cached( - jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns), - tuple(ensure_out_inst), saveable) + jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ + _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns), + tuple(in_inst), + tuple(ensure_out_unknowns), + tuple(ensure_out_inst), saveable) + if num_res_ref > 0: + raise ValueError( + "Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.") + return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res + +def partial_eval_jaxpr_stateful( + jaxpr: Jaxpr, + in_unknowns: Sequence[bool], + in_inst: Union[bool, Sequence[bool]], + ensure_out_unknowns: Union[bool, Sequence[bool]], + ensure_out_inst: Union[bool, Sequence[bool]], + saveable: Callable[..., bool], + ) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]: + if type(in_inst) is bool: + in_inst = (in_inst,) * len(jaxpr.invars) + if type(ensure_out_unknowns) is bool: + ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars) + if type(ensure_out_inst) is bool: + ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars) + jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ + _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns), + tuple(in_inst), + tuple(ensure_out_unknowns), + tuple(ensure_out_inst), saveable) + return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref @weakref_lru_cache def _partial_eval_jaxpr_custom_cached( @@ -1135,9 +1163,10 @@ def _partial_eval_jaxpr_custom_cached( ensure_out_unknowns: Tuple[bool, ...], ensure_out_inst: Tuple[bool, ...], saveable: Callable[..., bool], - ) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]: + ) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]: env: Dict[Var, Tuple[bool, bool]] = {} residuals: OrderedSet[Var] = OrderedSet() + residual_refs: OrderedSet[Var] = OrderedSet() def read(x: Atom) -> Tuple[bool, bool]: if type(x) is Var: @@ -1162,7 +1191,11 @@ def 
_partial_eval_jaxpr_custom_cached( if rule: eqn1, eqn2, unks_out, inst_out, res = rule(saveable, unks_in, inst_in, eqn) eqn1 and known_eqns.append(eqn1); eqn2 and staged_eqns.append(eqn2) # type: ignore - residuals.update(res) + for r in res: + if isinstance(r.aval, AbstractRef): + residual_refs.add(r) + else: + residuals.add(r) map(write, unks_out, inst_out, eqn.outvars) elif any(unks_in): inputs = map(ensure_instantiated, inst_in, eqn.invars) @@ -1189,24 +1222,27 @@ def _partial_eval_jaxpr_custom_cached( ins_known, _ = partition_list(in_unknowns, jaxpr.invars) outs_known, _ = partition_list(out_unknowns, jaxpr.outvars) + ref_res_is_input = [r in ins_known for r in residual_refs] + non_input_res_refs, _ = partition_list(ref_res_is_input, list(residual_refs)) + ins_known_and_ref_res = [*ins_known, *non_input_res_refs] known_outvars = [*outs_known, *residuals] - known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known, known_outvars, - known_eqns) - jaxpr_known = Jaxpr(jaxpr.constvars, ins_known, known_outvars, + known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res, + known_outvars, known_eqns) + jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars, known_eqns, known_effects) config.jax_enable_checks and core.check_jaxpr(jaxpr_known) _, ins_staged = partition_list(in_inst, jaxpr.invars) _, outs_staged = partition_list(out_inst, jaxpr.outvars) - staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns)) - staged_invars = [*residuals, *ins_staged] + staged_invars = [*residuals, *non_input_res_refs, *ins_staged] staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars, - outs_staged, staged_eqns) + outs_staged, staged_eqns) jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars, outs_staged, staged_eqns, staged_effects) config.jax_enable_checks and core.check_jaxpr(jaxpr_staged) - return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals) + return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals), + len(non_input_res_refs)) # A primitive rule for policy-driven partial evaluation returns a 5-tuple # with the components representing, respectively: @@ -1283,9 +1319,10 @@ def closed_call_partial_eval_custom_rule( ) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]: # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule. 
closed_jaxpr = eqn.params[jaxpr_param_name] - jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \ - partial_eval_jaxpr_custom(closed_jaxpr.jaxpr, unks_in, inst_in, - False, False, saveable) + jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \ + partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in, + False, False, saveable) + num_res = num_res_ref + num_res_out # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts) jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts) @@ -1299,18 +1336,25 @@ def closed_call_partial_eval_custom_rule( params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, params_staged) - residuals = [newvar(res_aval(params_known, a)) - for a in jaxpr_staged.in_avals[:num_res]] - eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], + residuals, ref_residuals = split_list( + [newvar(res_aval(params_known, v)) for v + in jaxpr_staged.in_avals[:num_res]], [num_res_out]) + eqn_known = new_jaxpr_eqn([*ins_known, *ref_residuals], + [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info) - eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, + eqn_staged = new_jaxpr_eqn([*residuals, *ref_residuals, *ins_staged], + out_binders_staged, eqn.primitive, params_staged, jaxpr_staged.effects, eqn.source_info) assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals) + assert len(ins_known) + len(ref_residuals) == len(jaxpr_known.jaxpr.invars) + assert len(ins_staged) + len(ref_residuals) + len(residuals) == len(jaxpr_staged.jaxpr.invars) + assert len(out_binders_known) + len(residuals) == len(jaxpr_known.jaxpr.outvars) new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is Var and not inst] - return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals + new_vars = [*new_inst, *residuals, *ref_residuals] + return eqn_known, eqn_staged, unks_out, inst_out, new_vars partial_eval_jaxpr_custom_rules[core.call_p] = \ partial(call_partial_eval_custom_rule, 'call_jaxpr', @@ -1536,16 +1580,25 @@ class DynamicJaxprTracer(core.Tracer): api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval") def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: - del outvars jaxpr_effects = set() all_vars = [*constvars, *invars] for eqn in eqns: for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): + if eff.input_index >= len(eqn.invars): + raise ValueError( + f"`JaxprInputEffect` {eff} is invalid." + f"\n Equation: {eqn}\n" + "\n Jaxpr: " + f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") invar = eqn.invars[eff.input_index] if invar not in all_vars: raise ValueError( - "`JaxprInputEffect` does not have corresponding input.") + f"`JaxprInputEffect` {eff} does not have " + f"corresponding input: {invar}." 
+ f"\n Equation: {eqn}\n" + "\n Jaxpr: " + f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") eff = eff.replace(input_index=all_vars.index(invar)) jaxpr_effects.add(eff) return jaxpr_effects diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 16b98bf9f..d3df3f54f 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -31,7 +31,8 @@ from jax._src import effects from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import util -from jax._src import state +from jax._src.state.discharge import register_discharge_rule, discharge_state +from jax._src.state.types import AbstractRef, RefEffect from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -237,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( (true_fun, false_fun), ops_tree, ops_avals, 'cond') - if any(isinstance(op_aval, state.AbstractRef) for op_aval in ops_avals): + if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals): raise ValueError("Cannot pass `Ref`s into `cond`.") true_jaxpr, false_jaxpr = jaxprs out_tree, false_out_tree = out_trees @@ -338,7 +339,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist - if any(isinstance(eff, state.RefEffect) for branch in branches for eff in + if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): raise NotImplementedError( "State effect not supported in vmap-of-cond.") @@ -426,7 +427,7 @@ def _cond_jvp(primals, tangents, branches, linear): def _cond_partial_eval(trace, *tracers, branches, linear): in_unknowns = [t.pval[0] is not None for t in tracers] index_uk, *ops_uk = in_unknowns - if any(isinstance(eff, state.RefEffect) for branch in branches for eff in + if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): raise NotImplementedError( "State effect not supported in cond partial-eval.") @@ -703,7 +704,7 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear): linear = [type(x) is ad.UndefinedPrimal for x in ops] in_avals = map(raise_to_shaped, branches[0].in_avals) num_res = len(ops) - sum(linear) - if any(isinstance(eff, state.RefEffect) for branch in branches for eff in + if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): raise NotImplementedError("State effect not supported in cond transpose.") @@ -856,10 +857,10 @@ def _cond_lowering(ctx, index, *args, branches, linear): mlir.register_lowering(cond_p, _cond_lowering) -@state.register_discharge_rule(cond_p) +@register_discharge_rule(cond_p) def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear): discharged_branches = tuple( - core.ClosedJaxpr(state.discharge_state(branch.jaxpr, ())[0], ()) + core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ()) for branch in branches) out_vals = cond_p.bind(*args, branches=discharged_branches, linear=linear) out_ref_vals, out_vals = util.split_list( @@ -868,5 +869,5 @@ def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear): new_invals = [] for aval in in_avals: new_invals.append( - next(ref_val_iter) if isinstance(aval, state.AbstractRef) 
else None) + next(ref_val_iter) if isinstance(aval, AbstractRef) else None) return new_invals, out_vals diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 9a6b1b871..c14ef392a 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -33,7 +33,11 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util -from jax._src import state +from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect) +from jax._src.state import discharge as state_discharge +from jax._src.state import primitives as state_primitives +from jax._src.state import utils as state_utils +from jax._src.state import types as state_types from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip, split_list, split_dict) from jax._src.lax.control_flow import loops @@ -50,15 +54,10 @@ T = TypeVar('T') class Ref(Generic[T]): pass Array = Any -ReadEffect = state.ReadEffect -WriteEffect = state.WriteEffect -AccumEffect = state.AccumEffect -StateEffect = state.StateEffect -AbstractRef = state.AbstractRef -ref_set = state.ref_set -ref_get = state.ref_get -ref_addupdate = state.ref_addupdate -discharge_state = state.discharge_state +ref_set = state_primitives.ref_set +ref_get = state_primitives.ref_get +ref_addupdate = state_primitives.ref_addupdate +discharge_state = state_discharge.discharge_state ## `for_loop` implementation @@ -100,12 +99,6 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef, f, state_avals) return jaxpr, consts, out_tree_thunk() -def val_to_ref_aval(x) -> AbstractRef: - aval = core.raise_to_shaped(core.get_aval(x)) - if type(aval) is not core.ShapedArray: - raise Exception(f"can't make ref from {x}") - return AbstractRef(aval) - def for_loop(nsteps: Union[int, Sequence[int]], body: Callable[[Array, Ref[S]], None], init_state: S, *, reverse: bool = False, unroll: int = 1) -> S: @@ -151,14 +144,14 @@ def for_loop(nsteps: Union[int, Sequence[int]], if len(nsteps) > 1: outer_step, *rest_steps = nsteps def wrapped_body(i, refs): - vals = tree_map(lambda ref: state.ref_get(ref, ()), refs) + vals = tree_map(lambda ref: ref_get(ref, ()), refs) vals = for_loop( rest_steps, functools.partial(body, i), vals, unroll=unroll) - tree_map(lambda ref, val: state.ref_set(ref, (), val), refs, vals) + tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals) return for_loop(outer_step, wrapped_body, init_state, unroll=unroll) nsteps, = nsteps flat_state, state_tree = tree_flatten(init_state) - state_avals = map(val_to_ref_aval, flat_state) + state_avals = map(state_utils.val_to_ref_aval, flat_state) idx_aval = core.ShapedArray((), jnp.dtype("int32")) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) @@ -247,7 +240,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], @for_p.def_effectful_abstract_eval def _for_abstract_eval(*avals, jaxpr, **__): # Find out for each of the `Ref`s in our jaxpr what effects they have. 
- jaxpr_aval_effects = state.get_ref_state_effects( + jaxpr_aval_effects = state_types.get_ref_state_effects( [v.aval for v in jaxpr.invars], jaxpr.effects)[1:] aval_effects = [set(eff.replace(input_index=eff.input_index - 1) for eff in effs) for aval, effs @@ -256,7 +249,7 @@ def _for_abstract_eval(*avals, jaxpr, **__): nonlocal_state_effects = core.join_effects(*aval_effects) return list(avals), nonlocal_state_effects -@state.register_discharge_rule(for_p) +@state_discharge.register_discharge_rule(for_p) def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr, reverse: bool, which_linear: Sequence[bool], nsteps: int, unroll: int @@ -388,7 +381,7 @@ def _is_read_only(ref_effects: Set[StateEffect]) -> bool: def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]: # Get effects for each of the jaxpr inputs and remove the loop index. - ref_effects = state.get_ref_state_effects( + ref_effects = state_types.get_ref_state_effects( [v.aval for v in jaxpr.invars], jaxpr.effects)[1:] # We first assume that *read-only `Ref`s* are loop-invariant. We can safely do # this because the only way something can be loop-varying is if we write to it @@ -771,7 +764,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): Potentially useful for testing and benchmarking. """ flat_state, state_tree = tree_flatten(init_state) - state_avals = map(val_to_ref_aval, flat_state) + state_avals = map(state_utils.val_to_ref_aval, flat_state) idx_aval = core.ShapedArray((), jnp.dtype("int32")) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py index 6f38469e1..fd451fb2a 100644 --- a/jax/_src/state/__init__.py +++ b/jax/_src/state/__init__.py @@ -15,7 +15,3 @@ from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect, AccumEffect, StateEffect, RefEffect, get_ref_state_effects, shaped_array_ref) -from jax._src.state.primitives import (ref_get, ref_set, ref_swap, - ref_addupdate, get_p, swap_p, - addupdate_p) -from jax._src.state.discharge import discharge_state, register_discharge_rule diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index ca20dbf33..13ddfebdd 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -15,24 +15,34 @@ from __future__ import annotations import dataclasses from functools import partial +import operator -from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union import numpy as np from jax import lax +from jax._src import api_util +from jax._src import ad_util from jax._src import core from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax._src.state.types import AbstractRef +from jax._src import source_info_util +from jax._src import tree_util +from jax._src.config import config +from jax._src.interpreters import ad +from jax._src.state.types import AbstractRef, RefEffect from jax._src.state.primitives import get_p, swap_p, addupdate_p -from jax._src.util import safe_map, safe_zip, split_list +from jax._src.state.utils import hoist_consts_to_refs +from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache, + partition_list, merge_lists, split_dict) ## JAX utilities map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +PyTreeDef = tree_util.PyTreeDef ## Discharging state @@ -241,3 +251,461 @@ def 
_closed_call_discharge_rule(
       else None for aval in in_avals)
   assert next(ref_vals_iter, None) is None
   return new_invals, out_vals
+
+## `run_state`
+
+run_state_p = core.Primitive("run_state")
+run_state_p.multiple_results = True
+
+def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
+                    which_linear: tuple[bool, ...]):
+  if config.jax_enable_checks:
+    core.check_jaxpr(jaxpr)
+    assert len(jaxpr.invars) == len(args)
+    assert len(which_linear) == len(args)
+  return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr,
+                             which_linear=which_linear)
+run_state_p.def_custom_bind(_run_state_bind)
+
+def _run_state_impl(*args: Any, jaxpr: core.Jaxpr,
+                    which_linear: tuple[bool, ...]):
+  del which_linear
+  discharged_jaxpr, consts = discharge_state(jaxpr, ())
+  return core.eval_jaxpr(discharged_jaxpr, consts, *args)
+run_state_p.def_impl(_run_state_impl)
+
+def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
+                             which_linear: tuple[bool, ...]):
+  del which_linear
+  # When we abstractly evaluate `run_state`, we want to keep track of which
+  # input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to
+  # "propagate" out its inner effects. Otherwise, the effects are local to this
+  # `run_state`.
+  is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)}
+  nonlocal_effects = {e for e in jaxpr.effects
+                      if (isinstance(e, RefEffect) and e.input_index in is_ref)
+                      or not isinstance(e, RefEffect)}
+  return avals, nonlocal_effects
+run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)
+
+def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *,
+                   jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
+  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
+  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
+  for _ in range(len(nonzero_tangents)):
+    _, out_nonzero_tangents = ad.jvp_jaxpr(
+        core.ClosedJaxpr(discharged_jaxpr, body_consts),
+        nonzero_tangents, instantiate=nonzero_tangents)
+    if out_nonzero_tangents == nonzero_tangents:
+      break
+    nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents)
+  else:
+    raise Exception("Invalid fixpoint")
+  del discharged_jaxpr, body_consts, out_nonzero_tangents
+  tangents = [ad.instantiate_zeros(t) if inst else t
+              for t, inst in zip(tangents, nonzero_tangents)]
+  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
+  closed_jvp_jaxpr, _ = ad.jvp_jaxpr(core.ClosedJaxpr(jaxpr, ()),
+                                     nonzero_tangents, [])
+  jvp_jaxpr_, jvp_consts = closed_jvp_jaxpr.jaxpr, closed_jvp_jaxpr.consts
+  jvp_jaxpr = hoist_consts_to_refs(jvp_jaxpr_)
+  jvp_which_linear = (*(False,) * len(jvp_consts), *which_linear,
+                      *(True,) * len(tangents))
+  out = run_state_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr,
+                         which_linear=jvp_which_linear)
+  out_consts, out_primals, out_tangents = split_list(out, [len(jvp_consts),
+                                                           len(primals)])
+  del out_consts
+  out_tangents_iter = iter(out_tangents)
+  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
+                  for p, nz in zip(out_primals, nonzero_tangents)]
+  return out_primals, out_tangents
+ad.primitive_jvps[run_state_p] = _run_state_jvp
+
+_save_everything = lambda *_, **__: True
+
+def _convert_outputs_to_writes(
+    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
+  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
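+  # E.g., schematically (a sketch, not literal jaxpr syntax): a jaxpr
+  #   (a : Ref{f32[]}) -> (sin(a[...]),)
+  # becomes a jaxpr with a trailing residual `Ref` input and no outputs:
+  #   (a : Ref{f32[]}, r : Ref{f32[]}) -> (); r[...] <- sin(a[...])
+  # so it can be bound by `run_state`, which takes only `Ref`s and returns
+  # their final values.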
+
+  in_avals = [v.aval for v in jaxpr.invars]
+  @lu.wrap_init
+  def eval_jaxpr(*refs):
+    # We split the refs into the original input refs and the dummy residual
+    # refs.
+    orig_refs, residual_refs = split_list(refs, [len(in_avals)])
+    residual_vals = core.eval_jaxpr(jaxpr, (), *orig_refs)
+    for res_ref, res_val in zip(residual_refs, residual_vals):
+      res_ref[...] = res_val
+    return []
+  res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
+                   else v.aval for v in jaxpr.outvars]
+  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+      eval_jaxpr, [*in_avals, *res_ref_avals])
+  assert not consts
+  return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]
+
+def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
+  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
+
+  @lu.wrap_init
+  def eval_jaxpr(*refs):
+    residual_refs, orig_refs = split_list(refs, [num_res])
+    residual_vals = [r[...] for r in residual_refs]
+    () = core.eval_jaxpr(jaxpr, (), *residual_vals, *orig_refs)
+    return []
+
+  res_val_avals, orig_ref_avals = \
+      split_list([v.aval for v in jaxpr.invars], [num_res])
+  res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
+                   aval for aval in res_val_avals]
+  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
+      eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
+  return jaxpr
+
+def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
+                            jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
+  num_inputs = len(tracers)
+  assert num_inputs == len(jaxpr.invars)
+  in_unknowns = [not t.pval.is_known() for t in tracers]
+  # We first need to run a fixpoint to determine which of the `Ref`s are
+  # unknown after running `run_state`. We want to use the jaxpr to determine
+  # which `Ref`s are unknown after executing the body given which `Ref`s are
+  # unknown before. However, the jaxpr has no outputs. Instead, we discharge
+  # the body and run the fixpoint with the discharged jaxpr. We can do this
+  # because the outputs of the discharged jaxpr are one-to-one with the inputs.
+  discharged_jaxpr_, discharged_consts = discharge_state(jaxpr, ())
+  discharged_jaxpr = pe.convert_constvars_jaxpr(discharged_jaxpr_)
+  for _ in range(num_inputs):
+    jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
+    _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
+        discharged_jaxpr, jaxpr_in_unknowns, jaxpr_in_unknowns,
+        in_unknowns, False, _save_everything)
+    # assert out_inst == out_unknowns
+    out_unknowns = list(out_unknowns)
+    if out_unknowns == in_unknowns:
+      break
+    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
+  else:
+    raise Exception("Invalid fixpoint")
+  del out_unknowns  # redundant since it's the same as `in_unknowns`
+  tracers = tuple(trace.instantiate_const(t) if uk else t  # type: ignore
+                  for t, uk in zip(tracers, in_unknowns))
+
+  # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
+  # primitives like `get`/`set`.
+  jaxpr_known_resout, jaxpr_unknown_resin_, _, _, num_res_out, num_res_ref = \
+    pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns, in_inst=in_unknowns,
+                                   ensure_out_unknowns=[], ensure_out_inst=[],
+                                   saveable=_save_everything)
+  # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref`
+  # and regular-valued inputs/outputs. However, we'd like to bind these jaxprs
+  # to a `run_state`, which expects only `Ref` inputs and no outputs. We need
+  # to convert both of these jaxprs into ones that are compatible with
+  # `run_state`.
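+  # Both conversions below are shape-level rewrites: `jaxpr_known_resout` gets
+  # `_convert_outputs_to_writes` applied to it and `jaxpr_unknown_resin_` gets
+  # `_convert_inputs_to_reads`, after which both have only `Ref` inputs and no
+  # outputs.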
+
+  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Ref`s
+  # to output residual values (none of them should be `Ref`s). We'll need to
+  # convert the output residual values into `Ref`s that start out empty and
+  # are written to at the end of the jaxpr.
+  num_res = num_res_out + num_res_ref
+
+  num_invars = len(jaxpr_known_resout.invars) - num_res_ref
+  _, res_ref_avals = split_list(
+      [v.aval for v in jaxpr_known_resout.invars], [num_invars])
+  res_avals = [a.inner_aval for a in res_ref_avals]
+  jaxpr_known, new_res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
+  # We now run the known jaxpr to obtain our residual values.
+  known_tracers, _ = partition_list(in_unknowns, tracers)
+  known_which_linear, _ = partition_list(in_unknowns, which_linear)
+  known_vals = [t.pval.get_known() for t in known_tracers]
+  all_res_avals = [*res_avals, *new_res_avals]
+  empty_res = map(ad_util.zeros_like_aval, all_res_avals)
+  jaxpr_known_args = [*known_vals, *empty_res]
+
+  jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
+  out_flat = run_state_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known,
+                              which_linear=jaxpr_known_which_linear)
+  known_outputs, residuals = split_list(out_flat, [len(known_tracers)])
+  residuals = map(trace.new_instantiated_const, residuals)
+  ref_res, nonref_res = split_list(residuals, [num_res_ref])
+
+  # Now we handle the `jaxpr_unknown` that expects residual values as inputs.
+  # This jaxpr is the output of `partial_eval_jaxpr_stateful`, which marks
+  # which inputs are actually used.
+  # `partial_eval_jaxpr_stateful` doesn't remove extra inputs/outputs for you,
+  # so we use `dce_jaxpr` here to do that.
+  # To make it compatible with `run_state`, we need to convert those residual
+  # values into `Ref`s.
+  jaxpr_unknown = _convert_inputs_to_reads(len(new_res_avals),
+                                           jaxpr_unknown_resin_)
+  _, unknown_tracers = partition_list(in_unknowns, tracers)
+  _, uk_which_linear = partition_list(in_unknowns, which_linear)
+  unknown_which_linear = (False,) * num_res + tuple(uk_which_linear)
+  unknown_inputs = [*nonref_res, *ref_res, *unknown_tracers]
+  # Outputs match inputs, so we construct output tracers that look like the
+  # input tracers.
+  res_ref_unknown_outputs = [
+      pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None)
+      for t in unknown_inputs]
+  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
+  source = source_info_util.current().replace(name_stack=name_stack)
+
+  assert len(unknown_inputs) == len(res_ref_unknown_outputs)
+  assert len(unknown_inputs) == len(jaxpr_unknown.invars)
+  uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear)
+  _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
+                                             **uk_params)
+  eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
+                          run_state_p, uk_params,
+                          eqn_effects, source)
+  for t in res_ref_unknown_outputs: t.recipe = eqn
+  _, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res])
+  return merge_lists(in_unknowns, known_outputs, unknown_outputs)
+pe.custom_partial_eval_rules[run_state_p] = _run_state_partial_eval
+
+def _run_state_partial_eval_custom(
+    saveable: Callable[..., bool],
+    in_unknowns: Sequence[bool],
+    in_inst: Sequence[bool],
+    eqn: core.JaxprEqn):
+  if not any(in_unknowns):
+    return eqn, None, in_unknowns, [False] * len(in_unknowns), []
+  jaxpr, which_linear = split_dict(eqn.params, ["jaxpr", "which_linear"])
+  num_inputs = len(eqn.invars)
+  # We first need to run a fixpoint to determine which of the `Ref`s are
+  # unknown after running `run_state`. However, the jaxpr has no outputs.
+  # Instead, we discharge the body and run the fixpoint with the discharged
+  # jaxpr. We can do this because the outputs of the discharged jaxpr are
+  # one-to-one with the inputs.
+  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
+  discharged_jaxpr = discharged_jaxpr.replace(
+      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
+      constvars=[])
+  in_unknowns, in_inst = list(in_unknowns), list(in_inst)
+  out_unknowns, out_inst = in_unknowns, in_unknowns
+  for _ in range(num_inputs):
+    jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
+    _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
+        discharged_jaxpr,
+        in_unknowns=jaxpr_in_unknowns,
+        in_inst=jaxpr_in_unknowns,
+        ensure_out_unknowns=in_unknowns,
+        ensure_out_inst=in_unknowns,
+        saveable=saveable)
+    out_unknowns = list(out_unknowns)
+    if out_unknowns == in_unknowns:
+      break
+    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
+  else:
+    if num_inputs > 0: raise Exception("Invalid fixpoint")
+  del out_unknowns  # Redundant since it's the same as `in_unknowns`
+  new_inst = [x for x, already, inst in zip(eqn.invars, in_inst, out_inst)
+              if type(x) is core.Var and inst and not already]
+
+  # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
+  # primitives like `get`/`set`.
+  jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res_out, num_res_ref = \
+      pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns,
+                                     in_unknowns, [], [], saveable)
+  num_res = num_res_ref + num_res_out
+  # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref`
+  # and non-`Ref` inputs/outputs. However, we'd like to bind these jaxprs to a
+  # `run_state`, which expects only `Ref` inputs and no outputs. We need to
+  # convert both of these jaxprs into ones that are compatible with
+  # `run_state`.
+  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
+
+  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Ref`s
+  # to output residual values (none of them should be `Ref`s).
+  # We'll need to convert the output residual values into `Ref`s that start
+  # out empty and are written to at the end of the jaxpr.
+  jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
+
+  # In a stateful partial_eval, the residuals should be `Ref`s.
+  res_avals = map(AbstractRef, res_avals)  # type: ignore
+
+  known_invars, staged_invars = partition_list(in_unknowns, eqn.invars)
+  known_outvars, staged_outvars = partition_list(in_unknowns, eqn.outvars)
+  newvar = core.gensym()
+  _, res_ref_avals = split_list([v.aval for v in jaxpr_known_resout.invars],
+                                [len(known_invars)])
+  nonref_resvars = map(newvar, res_avals)
+  ref_resvars = map(newvar, res_ref_avals)
+  known_out_resvars = map(newvar, [*res_ref_avals, *res_avals])
+
+  known_which_linear, _ = partition_list(in_unknowns, which_linear)
+  jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
+  known_and_res_invars = [*known_invars, *ref_resvars, *nonref_resvars]
+
+  known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear)
+  _, known_effects = run_state_p.abstract_eval(
+      *[v.aval for v in known_and_res_invars], **known_params)
+  eqn_known = pe.new_jaxpr_eqn(known_and_res_invars,
+                               [*known_outvars, *known_out_resvars],
+                               run_state_p, known_params,
+                               known_effects, eqn.source_info)
+
+  jaxpr_staged = _convert_inputs_to_reads(len(res_avals), jaxpr_staged_resin_)
+
+  _, staged_which_linear = partition_list(in_unknowns, which_linear)
+  which_linear_unknown = (*[False] * num_res, *staged_which_linear)
+  staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown)
+  rejiggered_resvars = [*nonref_resvars, *ref_resvars]
+  _, staged_invars = partition_list(in_unknowns, eqn.invars)
+  res_staged_invars = [*rejiggered_resvars, *staged_invars]
+  _, staged_effects = run_state_p.abstract_eval(
+      *[v.aval for v in res_staged_invars], **staged_params)
+  _, staged_outvars = partition_list(in_unknowns, eqn.outvars)
+  if num_res:
+    @lu.wrap_init
+    def staged(*args):
+      out = run_state_p.bind(*args, **staged_params)
+      return out[num_res:]
+    staged_call_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
+        staged, [v.aval for v in res_staged_invars])
+    eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
+                                  staged_outvars,
+                                  core.closed_call_p,
+                                  dict(call_jaxpr=core.ClosedJaxpr(
+                                      staged_call_jaxpr, ())),
+                                  staged_effects, eqn.source_info)
+    assert len(res_staged_invars) == len(staged_call_jaxpr.invars)
+    assert len(staged_outvars) == len(staged_call_jaxpr.outvars)
+  else:
+    eqn_staged = pe.new_jaxpr_eqn(staged_invars,
+                                  staged_outvars,
+                                  run_state_p,
+                                  staged_params,
+                                  staged_effects, eqn.source_info)
+  new_vars = [*new_inst, *nonref_resvars, *ref_resvars]
+  return eqn_known, eqn_staged, in_unknowns, in_unknowns, new_vars
+pe.partial_eval_jaxpr_custom_rules[run_state_p] = _run_state_partial_eval_custom
+
+def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
+                     ) -> tuple[core.Jaxpr, Any]:
+  def trans(*args):
+    # First we want to run the computation to read all the residual refs. We
+    # can do that by using partial evaluation with all linear inputs unknown.
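+    # With every linear input marked unknown, `res_jaxpr_` below is the
+    # nonlinear piece that computes the residuals, and `tangent_jaxpr_` is the
+    # linear piece that `ad.backward_pass` can then transpose.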
+    res_jaxpr_, tangent_jaxpr_, *_, num_res_out, num_res_ref = \
+        pe.partial_eval_jaxpr_stateful(jaxpr, which_linear,
+                                       in_inst=which_linear,
+                                       ensure_out_inst=[],
+                                       ensure_out_unknowns=[],
+                                       saveable=_save_everything)
+
+    num_unknown = sum(which_linear)
+    num_known = len(jaxpr.invars) - num_unknown
+    res_args, _ = partition_list(which_linear, args)
+    res_jaxpr_avals = [v.aval for v in res_jaxpr_.invars]
+    _, res_avals = split_list(res_jaxpr_avals, [num_known])
+    res_avals = [a.inner_aval for a in res_avals]
+    all_avals = [*res_avals, *[v.aval for v in res_jaxpr_.outvars]]
+    empty_res = map(ad.zeros_like_aval, all_avals)
+    res_jaxpr, _ = _convert_outputs_to_writes(res_jaxpr_)
+    res = run_state_p.bind(*res_args, *empty_res, jaxpr=res_jaxpr,
+                           which_linear=(False,) * (len(res_args) +
+                                                    len(empty_res)))
+    res = res[len(res_args):]
+    ref_res_, nonref_res_ = split_list(res, [num_res_ref])
+
+    # Now that we have residual values, we run the tangent jaxpr. It takes as
+    # input the residuals and all the refs (at least, the ones that are used
+    # in the body). Luckily, `tangent_jaxpr_` has all known and unknown inputs!
+    tangent_jaxpr, used_inputs = pe.dce_jaxpr(tangent_jaxpr_, [])
+    used_res, used_cts = split_list(used_inputs, [len(res)])
+    used_nonref_res, used_ref_res = split_list(used_res, [num_res_out])
+    _, nonref_res = partition_list(used_nonref_res, nonref_res_)
+    _, ref_res = partition_list(used_ref_res, ref_res_)
+    primals_args = [*nonref_res, *ref_res]
+    _, tangent_args = partition_list(which_linear, args)
+    _, ct_args = partition_list(used_cts, tangent_args)
+    ad.backward_pass(
+        tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
+    return []
+  jaxpr_trans, _, consts = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
+  return jaxpr_trans, consts
+
+def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,
+                         which_linear: tuple[bool, ...]):
+  # If any in_ct is nonzero, we definitely want it in `transpose_args` (and the
+  # corresponding x in args could be an undefined primal, but doesn't have to
+  # be). For non-residual inputs:
+  #   getting and setting => (nonzero ct, UndefinedPrimal arg)
+  #   just setting        => (nonzero ct, not UndefinedPrimal, dummy value)
+  #   just getting        => (zero ct   , UndefinedPrimal arg)
+  # For residual inputs:
+  #   (zero ct, not UndefinedPrimal)
+  assert any(which_linear)
+  transpose_args = []
+  for x, ct in zip(args, in_cts):
+    if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x):
+      # this is a residual, take x!
+      transpose_args.append(x)
+    elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x):
+      # the body was 'just getting', plug in a zero
+      transpose_args.append(ad_util.zeros_like_aval(x.aval))
+    elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x):
+      # the body was 'just setting', grab that cotangent! x is dummy
+      transpose_args.append(ct)
+    elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x):
+      # the body was 'getting and setting', grab that cotangent!
+ transpose_args.append(ct) + jaxpr_transpose_, consts = _transpose_jaxpr(jaxpr, which_linear) + jaxpr_transpose = hoist_consts_to_refs(jaxpr_transpose_) + which_linear = (*[False] * len(consts), *which_linear) + const_all_outs = run_state_p.bind(*consts, *transpose_args, + jaxpr=jaxpr_transpose, + which_linear=which_linear) + _, all_outs = split_list(const_all_outs, [len(consts)]) + ct_outs = [ct if ad.is_undefined_primal(x) else None + for x, ct in zip(args, all_outs)] + return ct_outs +ad.primitive_transposes[run_state_p] = _run_state_transpose + +@register_discharge_rule(run_state_p) +def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue], + out_avals: Sequence[core.AbstractValue], + *args: Any, jaxpr: core.Jaxpr, + which_linear: Sequence[bool]): + del out_avals + out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear) + new_invals = [] + for aval, out_val in zip(in_avals, out_vals): + new_invals.append(out_val if isinstance(aval, AbstractRef) else None) + return new_invals, out_vals + +def initial_style_jaxpr( + fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue] + ) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]: + return _initial_style_jaxpr(fun, in_tree, tuple(in_avals)) + +@weakref_lru_cache +def _initial_style_jaxpr(fun, in_tree, in_avals): + fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), + tree_util.treedef_tuple((in_tree,))) + debug = pe.debug_info(fun_, in_tree, False, out_tree_thunk, 'run_state') + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) + return jaxpr, consts, out_tree_thunk() + +def run_state(f: Callable[..., None]): + def wrapped(args): + flat_args, in_tree = tree_util.tree_flatten(args) + avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args] + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals)) + jaxpr = hoist_consts_to_refs(jaxpr_) + which_linear = (False,) * (len(consts) + len(flat_args)) + out_const_flat = run_state_p.bind(*consts, *flat_args, jaxpr=jaxpr, + which_linear=which_linear) + _, out_flat = split_list(out_const_flat, [len(consts)]) + return in_tree.unflatten(out_flat) + return wrapped + +def run_state_reference(f: Callable[..., None]): + def wrapped(args): + flat_args, in_tree = tree_util.tree_flatten(args) + avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args] + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals)) + jaxpr = hoist_consts_to_refs(jaxpr_) + discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) + out_const_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, + *consts, *args) + _, out_flat = split_list(out_const_flat, [len(consts)]) + return in_tree.unflatten(out_flat) + return wrapped diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 2a486c032..edf1fcfac 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -18,8 +18,7 @@ from typing import Any, List, Tuple, Union import numpy as np -from jax import lax - +import jax from jax._src import ad_util from jax._src import core from jax._src import pretty_printer as pp @@ -421,7 +420,7 @@ def _get_vmap(batched_args, batched_dims, *, indexed_dims): # `idxs` doesn't include the non indexed dims. 
idx_place = [i for i, i_dim in enumerate(indexed_dims) if i_dim].index(ref_dim) - iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) + iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) idxs = tuple_insert(idxs, idx_place, iota) else: bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape) @@ -454,7 +453,7 @@ def _swap_vmap(batched_args, batched_dims, *, indexed_dims): indexed_dims = tuple_insert(indexed_dims, ref_dim, True) idx_place = [i for i, i_dim in enumerate(indexed_dims) if i_dim].index(ref_dim) - iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) + iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) idxs = tuple_insert(idxs, idx_place, iota) val = batching.moveaxis(val, val_dim, 0) bdim_out = 0 @@ -487,7 +486,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims): idx_place = [i for i, i_dim in enumerate(indexed_dims) if i_dim].index(ref_dim) idxs_shape, = {i.shape for i in idxs} or [()] - iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) + iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) idxs = tuple_insert(idxs, idx_place, iota) val = batching.moveaxis(val, val_dim, 0) return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), [] diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py new file mode 100644 index 000000000..95ef83141 --- /dev/null +++ b/jax/_src/state/utils.py @@ -0,0 +1,54 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for tracing stateful functions.""" + +from jax.interpreters import partial_eval as pe +from jax._src import core +from jax._src import linear_util as lu +from jax._src.state import AbstractRef +from jax._src.util import (partition_list, merge_lists, split_list, safe_map, + safe_zip) +from jax._src.state.primitives import ref_get + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr: + all_const_avals = [var.aval for var in jaxpr.constvars] + is_const_ref = [isinstance(var.aval, AbstractRef) for var in + jaxpr.constvars] + const_avals_, const_ref_avals = partition_list(is_const_ref, all_const_avals) + const_avals = map(AbstractRef, const_avals_) + merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals) + arg_avals = [var.aval for var in jaxpr.invars] + in_avals = [*merged_const_avals, *arg_avals] + num_consts = len(merged_const_avals) + + def _hoist(*consts_args): + all_consts, args = split_list(consts_args, [num_consts]) + consts, const_refs = partition_list(is_const_ref, all_consts) + # We immediately read the const values out of the `Ref`s. 
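+    # Schematically, { lambda c ; a. ... } becomes { lambda ; c_ref a. let
+    # c <- c_ref[...] in ... }: each non-`Ref` constant turns into a leading
+    # `Ref` input that is read once up front, while `Ref` constants are
+    # passed through unchanged.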
+ consts = map(lambda x: ref_get(x, ()), consts) + all_consts = merge_lists(is_const_ref, consts, const_refs) + return core.eval_jaxpr(jaxpr, all_consts, *args) + hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(_hoist), in_avals) + assert not consts, "All consts should have been converted to refs" + return hoisted_jaxpr + +def val_to_ref_aval(x) -> AbstractRef: + aval = core.raise_to_shaped(core.get_aval(x)) + if type(aval) is not core.ShapedArray: + raise Exception(f"can't make ref from {x}") + return AbstractRef(aval) diff --git a/tests/state_test.py b/tests/state_test.py index f0f9f4c22..ba4660d1a 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -14,14 +14,15 @@ from functools import partial import itertools as it -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union from absl.testing import absltest from absl.testing import parameterized import numpy as np import jax -from jax._src import core +from jax import random from jax import lax +from jax._src import core from jax._src import linear_util as lu from jax.config import config from jax._src.interpreters import partial_eval as pe @@ -38,7 +39,13 @@ try: except (ModuleNotFoundError, ImportError): CAN_USE_HYPOTHESIS = False -from jax._src import state +from jax._src.state.discharge import (run_state, run_state_reference, + discharge_state) +from jax._src.state.primitives import (get_p, swap_p, addupdate_p, + ref_addupdate, ref_get, ref_set, + ref_swap) +from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect, + AccumEffect, AbstractRef) config.parse_flags_with_absl() @@ -46,20 +53,20 @@ class StatePrimitivesTest(jtu.JaxTestCase): def test_cant_eval_get_primitive(self): with self.assertRaises(ValueError): - state.get_p.bind(jnp.ones(5)) + get_p.bind(jnp.ones(5)) def test_cant_eval_swap_primitive(self): with self.assertRaises(ValueError): - state.swap_p.bind(jnp.ones(5), jnp.zeros(5)) + swap_p.bind(jnp.ones(5), jnp.zeros(5)) def test_cant_eval_addupdate_primitive(self): with self.assertRaises(ValueError): - state.addupdate_p.bind(jnp.ones(5), jnp.zeros(5)) + addupdate_p.bind(jnp.ones(5), jnp.zeros(5)) def test_get_abstract_aval_must_take_in_refs(self): ref_aval = core.ShapedArray((), jnp.float32) def f(x_ref): - return [state.ref_get(x_ref, ())] + return [ref_get(x_ref, ())] with self.assertRaises(ValueError): pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval]) @@ -88,9 +95,9 @@ class StatePrimitivesTest(jtu.JaxTestCase): ) def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None, out_dtype=None, should_error=False): - ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) + ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) def f(x_ref): - out = state.ref_get(x_ref, idx) + out = ref_get(x_ref, idx) return [out] if should_error: with self.assertRaises(Exception): @@ -99,7 +106,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( lu.wrap_init(f), [ref_aval]) self.assertSetEqual(jaxpr.effects, - {state.ReadEffect(len(jaxpr.constvars))}) + {ReadEffect(len(jaxpr.constvars))}) self.assertLen(out_avals, 1) out_aval, = out_avals self.assertIsInstance(out_aval, core.ShapedArray) @@ -110,7 +117,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): ref_aval = core.ShapedArray((), jnp.float32) val_aval = core.ShapedArray((), jnp.float32) def f(x_ref, val): - return [state.ref_swap(x_ref, (), val)] + return 
[ref_swap(x_ref, (), val)] with self.assertRaises(ValueError): pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) @@ -154,10 +161,10 @@ class StatePrimitivesTest(jtu.JaxTestCase): def test_swap_abstract_eval(self, ref_shape, ref_dtype, val_shape, val_dtype, idx, out_shape=None, out_dtype=None, should_error=False): - ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) + ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) val_aval = core.ShapedArray(val_shape, val_dtype) def f(x_ref, val): - out = state.ref_swap(x_ref, idx, val) + out = ref_swap(x_ref, idx, val) return [out] if should_error: with self.assertRaises(Exception): @@ -166,7 +173,7 @@ class StatePrimitivesTest(jtu.JaxTestCase): jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( lu.wrap_init(f), [ref_aval, val_aval]) self.assertSetEqual(jaxpr.effects, - {state.WriteEffect(len(jaxpr.constvars))}) + {WriteEffect(len(jaxpr.constvars))}) self.assertLen(out_avals, 1) out_aval, = out_avals self.assertIsInstance(out_aval, core.ShapedArray) @@ -210,10 +217,10 @@ class StatePrimitivesTest(jtu.JaxTestCase): def test_addupdate_abstract_eval(self, ref_shape, ref_dtype, val_shape, val_dtype, idx, out_shape=None, out_dtype=None, should_error=False): - ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) + ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) val_aval = core.ShapedArray(val_shape, val_dtype) def f(x_ref, val): - state.ref_addupdate(x_ref, idx, val) + ref_addupdate(x_ref, idx, val) return [] if should_error: with self.assertRaises(Exception): @@ -222,14 +229,14 @@ class StatePrimitivesTest(jtu.JaxTestCase): jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( lu.wrap_init(f), [ref_aval, val_aval]) self.assertSetEqual(jaxpr.effects, - {state.AccumEffect(len(jaxpr.constvars))}) + {AccumEffect(len(jaxpr.constvars))}) self.assertLen(out_avals, 0) def test_addupdate_abstract_eval_must_take_in_refs(self): ref_aval = core.ShapedArray((), jnp.float32) val_aval = core.ShapedArray((), jnp.float32) def f(x_ref, val): - return [state.ref_addupdate(x_ref, (), val)] + return [ref_addupdate(x_ref, (), val)] with self.assertRaises(ValueError): pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval]) @@ -240,37 +247,37 @@ class StatePrimitivesTest(jtu.JaxTestCase): x[()] = jnp.int32(2) return (x[()],) jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)]) + lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) self.assertLen(consts, 0) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) - self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p) - self.assertEqual(jaxpr.eqns[1].primitive, state.swap_p) - self.assertEqual(jaxpr.eqns[2].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[0].primitive, swap_p) + self.assertEqual(jaxpr.eqns[1].primitive, swap_p) + self.assertEqual(jaxpr.eqns[2].primitive, get_p) def test_can_represent_addupdate_in_jaxprs(self): def body(x): - state.ref_addupdate(x, (), jnp.int32(1)) + ref_addupdate(x, (), jnp.int32(1)) return (x[()],) jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)]) + lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) self.assertLen(consts, 0) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) - self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p) + self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p) def test_get_custom_pretty_printing_rule(self): def 
body(x_ref): x = x_ref[()] return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)]) + lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False)) def body(x_ref): x = x_ref[:, 0] return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32)]) + lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)]) self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False)) def test_set_custom_pretty_printing_rule(self): @@ -278,48 +285,48 @@ class StatePrimitivesTest(jtu.JaxTestCase): x_ref[()] = jnp.int32(2) return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)]) + lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x_ref[:, 0] = val return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32), + lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False)) def test_swap_custom_pretty_printing_rule(self): def body(x_ref): - x = state.ref_swap(x_ref, (), jnp.int32(2)) + x = ref_swap(x_ref, (), jnp.int32(2)) return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)]) + lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): - x = state.ref_swap(x_ref, (slice(None), 0), val) + x = ref_swap(x_ref, (slice(None), 0), val) return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32), + lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("c:i32[1], a[:,0] <- a[:,0], b", jaxpr.pretty_print(use_color=False)) def test_addupdate_custom_pretty_printing_rule(self): def body(x_ref): - state.ref_addupdate(x_ref, (), jnp.int32(2)) + ref_addupdate(x_ref, (), jnp.int32(2)) return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)]) + lu.wrap_init(body), [shaped_array_ref((), jnp.int32)]) self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): - state.ref_addupdate(x_ref, (slice(None), 0), val) + ref_addupdate(x_ref, (slice(None), 0), val) return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32), + lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False)) @@ -333,11 +340,11 @@ class StatePrimitivesTest(jtu.JaxTestCase): def g(r, rdot): return jax.jvp(f, (r,), (rdot,)) - in_avals = [state.shaped_array_ref((), jnp.dtype('float32')), - state.shaped_array_ref((), jnp.dtype('float32'))] + in_avals = [shaped_array_ref((), jnp.dtype('float32')), + shaped_array_ref((), jnp.dtype('float32'))] jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) - self.assertEqual(jaxpr.eqns[0].primitive, state.get_p) - self.assertEqual(jaxpr.eqns[1].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[0].primitive, get_p) + self.assertEqual(jaxpr.eqns[1].primitive, get_p) def test_swap_jvp(self): @@ -349,33 +356,33 @@ class 
StatePrimitivesTest(jtu.JaxTestCase): def g(r, rdot): return jax.jvp(f, (r,), (rdot,)) - in_avals = [state.shaped_array_ref((), jnp.dtype('float32')), - state.shaped_array_ref((), jnp.dtype('float32'))] + in_avals = [shaped_array_ref((), jnp.dtype('float32')), + shaped_array_ref((), jnp.dtype('float32'))] jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) - self.assertEqual(jaxpr.eqns[0].primitive, state.get_p) - self.assertEqual(jaxpr.eqns[1].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[0].primitive, get_p) + self.assertEqual(jaxpr.eqns[1].primitive, get_p) self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p) self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p) self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p) - self.assertEqual(jaxpr.eqns[5].primitive, state.swap_p) - self.assertEqual(jaxpr.eqns[6].primitive, state.swap_p) + self.assertEqual(jaxpr.eqns[5].primitive, swap_p) + self.assertEqual(jaxpr.eqns[6].primitive, swap_p) def test_addupdate_jvp(self): def f(a): - state.ref_addupdate(a, (), jnp.float32(1.)) + ref_addupdate(a, (), jnp.float32(1.)) return a[()] def g(r, rdot): return jax.jvp(f, (r,), (rdot,)) - in_avals = [state.shaped_array_ref((), jnp.dtype('float32')), - state.shaped_array_ref((), jnp.dtype('float32'))] + in_avals = [shaped_array_ref((), jnp.dtype('float32')), + shaped_array_ref((), jnp.dtype('float32'))] jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) - self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p) - self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p) - self.assertEqual(jaxpr.eqns[2].primitive, state.get_p) - self.assertEqual(jaxpr.eqns[3].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p) + self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p) + self.assertEqual(jaxpr.eqns[2].primitive, get_p) + self.assertEqual(jaxpr.eqns[3].primitive, get_p) @jtu.sample_product( [dict(ref_shape=ref_shape, ref_bdim=ref_bdim, idx_shape=idx_shape, @@ -392,11 +399,11 @@ class StatePrimitivesTest(jtu.JaxTestCase): op=[ lambda x_ref, indexer: [x_ref[indexer]], lambda x_ref, indexer: [ - state.ref_swap(x_ref, indexer, + ref_swap(x_ref, indexer, jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)])], lambda x_ref, indexer: ( - state.ref_addupdate(x_ref, indexer, + ref_addupdate(x_ref, indexer, jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]) or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]]) @@ -420,8 +427,8 @@ class StatePrimitivesTest(jtu.JaxTestCase): return tuple_insert(shape, idx, axis_size) batched_ref_shape = maybe_insert(ref_shape, ref_bdim) - ref_aval = state.shaped_array_ref(ref_shape, float_) - bat_ref_aval = state.shaped_array_ref(batched_ref_shape, float_) + ref_aval = shaped_array_ref(ref_shape, float_) + bat_ref_aval = shaped_array_ref(batched_ref_shape, float_) idx_avals = [core.ShapedArray(idx_shape, int_) for _ in idx_bdims] @@ -444,12 +451,12 @@ class StatePrimitivesTest(jtu.JaxTestCase): f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals]) - jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs) # vmap-of-discharge stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(f), [ref_aval, *idx_avals]) - jaxpr_, consts_ = 
+    jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
     f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                          in_axes=(ref_bdim, *idx_bdims),
                          out_axes=[out_bdim, ref_bdim])
@@ -463,13 +470,13 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_get(self):
     def f(a_ref):
-      a = state.ref_get(a_ref, ())
+      a = ref_get(a_ref, ())
       return [a + 1]
-    in_avals = [state.shaped_array_ref((), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that just adds 1.
-    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
+    discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 1)
     self.assertLen(discharged_jaxpr.outvars, 2)
     self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p)
@@ -479,13 +486,13 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_get_with_slice(self):
     def f(a_ref):
-      a = state.ref_get(a_ref, (0, 1))
+      a = ref_get(a_ref, (0, 1))
       return [a + 1]
-    in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that just adds 1.
-    discharged_jaxpr, () = state.discharge_state(stateful_jaxpr, consts)
+    discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 1)
     self.assertLen(discharged_jaxpr.outvars, 2)
     self.assertIn(lax.dynamic_slice_p,
@@ -500,10 +507,10 @@ class StateDischargeTest(jtu.JaxTestCase):
     def f(a_ref):
       a = a_ref[jnp.array([0, 1])]
       return [a + 1]
-    in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f), in_avals)
-    discharged_jaxpr, discharged_consts = state.discharge_state(
+    discharged_jaxpr, discharged_consts = discharge_state(
         stateful_jaxpr, consts)
     inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
     outval, refval = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
@@ -512,15 +519,15 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_set(self):
     def f(a_ref, b):
-      state.ref_set(a_ref, (), b + 1)
+      ref_set(a_ref, (), b + 1)
       return []
-    in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
+    in_avals = [shaped_array_ref((), jnp.dtype('float32')),
                 core.ShapedArray((), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that ignores the first
     # value and returns second value plus 1.
-    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
+    discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 2)
     self.assertLen(discharged_jaxpr.outvars, 1)
     self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
@@ -530,13 +537,13 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_set_with_slice(self):
     def f(a_ref):
-      state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
+      ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
       return []
-    in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should turn this into a jaxpr that updates a slice with ones.
-    discharged_jaxpr, () = state.discharge_state(stateful_jaxpr, consts)
+    discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 1)
     self.assertLen(discharged_jaxpr.outvars, 1)
     self.assertIn(lax.dynamic_update_slice_p,
@@ -552,10 +559,10 @@ class StateDischargeTest(jtu.JaxTestCase):
     def f(a_ref):
       a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
       return []
-    in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
-    discharged_jaxpr, discharged_consts = state.discharge_state(
+    discharged_jaxpr, discharged_consts = discharge_state(
         stateful_jaxpr, consts)
     inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
     refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
@@ -563,15 +570,15 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_addupdate(self):
     def f(a_ref, b):
-      state.ref_addupdate(a_ref, (), b + 1)
+      ref_addupdate(a_ref, (), b + 1)
       return []
-    in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
+    in_avals = [shaped_array_ref((), jnp.dtype('float32')),
                 core.ShapedArray((), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that adds the first value,
     # second value, and 1.
-    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
+    discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 2)
     self.assertLen(discharged_jaxpr.outvars, 1)
     self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
@@ -581,13 +588,13 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_addupdate_with_slice(self):
     def f(a_ref):
-      state.ref_addupdate(a_ref, (0, 1),
+      ref_addupdate(a_ref, (0, 1),
                     jnp.ones(2, dtype=jnp.dtype('float32')))
       return []
-    in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
-    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
+    discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 1)
     self.assertLen(discharged_jaxpr.outvars, 1)
     self.assertIn(lax.dynamic_update_slice_p,
@@ -602,13 +609,13 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_addupdate_with_gather(self):
     def f(a_ref):
-      state.ref_addupdate(a_ref, (jnp.array([0, 1]),),
+      ref_addupdate(a_ref, (jnp.array([0, 1]),),
                     jnp.ones((2, 3), 'float32'))
       return []
-    in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
-    discharged_jaxpr, discharged_consts = state.discharge_state(
+    discharged_jaxpr, discharged_consts = discharge_state(
         stateful_jaxpr, consts)
     inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
     refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
@@ -616,13 +623,13 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_discharge_jaxpr_with_multiple_outputs(self):
     def f(a_ref):
-      a = state.ref_get(a_ref, ())
+      a = ref_get(a_ref, ())
       b = a + 1
       return [a, b]
-    in_avals = [state.shaped_array_ref((4,), jnp.dtype('float32'))]
+    in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
-    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
+    discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 1)
     self.assertLen(discharged_jaxpr.outvars, 3)
     inval = jnp.arange(4., dtype=jnp.float32)
@@ -633,33 +640,33 @@ class StateDischargeTest(jtu.JaxTestCase):

   def test_partially_discharging_jaxpr_keeps_refs(self):
     def f(a_ref, b_ref):
-      state.ref_set(a_ref, (), jnp.ones(4, jnp.float32))
-      state.ref_set(b_ref, (), jnp.ones(4, jnp.float32))
+      ref_set(a_ref, (), jnp.ones(4, jnp.float32))
+      ref_set(b_ref, (), jnp.ones(4, jnp.float32))
       return []
     in_avals = [
-        state.shaped_array_ref((4,), jnp.dtype('float32')),
-        state.shaped_array_ref((4,), jnp.dtype('float32'))
+        shaped_array_ref((4,), jnp.dtype('float32')),
+        shaped_array_ref((4,), jnp.dtype('float32'))
     ]
     stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
-    discharged_jaxpr, _ = state.discharge_state(
+    discharged_jaxpr, _ = discharge_state(
         stateful_jaxpr, consts, should_discharge=[False, True])
     self.assertLen(discharged_jaxpr.invars, 2)
     self.assertLen(discharged_jaxpr.outvars, 1)
-    self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.AbstractRef)
+    self.assertIsInstance(discharged_jaxpr.invars[0].aval, AbstractRef)
     self.assertIsInstance(discharged_jaxpr.invars[1].aval, core.ShapedArray)
     self.assertEqual(discharged_jaxpr.effects,
-                     {state.WriteEffect(len(discharged_jaxpr.constvars))})
+                     {WriteEffect(len(discharged_jaxpr.constvars))})
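Taken together, the discharge tests above fix the contract of discharge_state: each Ref input of a stateful jaxpr becomes an ordinary array input, each discharged Ref's final value is appended to the outputs, and should_discharge selects which Refs to lower. A minimal sketch of that round trip, mirroring test_discharge_set; the direct jax._src.state import paths are an assumption based on how this test file uses the helpers:

  import jax.numpy as jnp
  from jax._src import core
  from jax._src import linear_util as lu
  from jax._src.interpreters import partial_eval as pe
  # Assumed import locations for the state helpers exercised in this file.
  from jax._src.state import discharge_state, ref_set, shaped_array_ref

  def body(a_ref, b):
    ref_set(a_ref, (), b + 1.)  # a_ref[()] = b + 1
    return []

  in_avals = [shaped_array_ref((), jnp.dtype('float32')),
              core.ShapedArray((), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(body),
                                                        in_avals)
  # The Ref input becomes a plain array input and its final value becomes the
  # sole output, so evaluating the discharged jaxpr is pure.
  jaxpr, consts = discharge_state(stateful_jaxpr, consts)
  out, = core.eval_jaxpr(jaxpr, consts, jnp.float32(0.), jnp.float32(41.))
  assert out == 42.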

   def test_ellipsis_index(self):
     def f(ref):
-      state.ref_set(ref, ..., jnp.array(0., dtype=jnp.float32))
-      state.ref_get(ref, ...)
+      ref_set(ref, ..., jnp.array(0., dtype=jnp.float32))
+      ref_get(ref, ...)
       ref[...] = jnp.array(0., dtype=jnp.float32)
       ref[...]
       return []
-    in_avals = [state.shaped_array_ref((), jnp.float32)]
+    in_avals = [shaped_array_ref((), jnp.float32)]
     pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)

@@ -672,7 +679,7 @@ if CAN_USE_HYPOTHESIS:
   Shape = tuple[int, ...]

   class IndexParam(NamedTuple):
-    ref_aval: state.shaped_array_ref
+    ref_aval: shaped_array_ref
     ref_shape: Shape
     indexed_dims: list[bool]
     idx_avals: tuple[core.ShapedArray, ...]
@@ -692,7 +699,7 @@ if CAN_USE_HYPOTHESIS:
       slice_shape = (*idx_shape, *sliced_shape)
     else:
       slice_shape = ref_shape
-    ref_aval = state.shaped_array_ref(ref_shape, np.float32)
+    ref_aval = shaped_array_ref(ref_shape, np.float32)
     idx_avals = tuple(core.ShapedArray(idx_shape, np.int32)
                       for _ in range(sum(indexed_dims)))
     slice_aval = core.ShapedArray(slice_shape, np.float32)
@@ -704,7 +711,7 @@ if CAN_USE_HYPOTHESIS:
     ref_bdim: Optional[int]
     non_slice_idx_bdims: tuple[Optional[int], ...]
     slice_bdim: int
-    bat_ref_aval: state.shaped_array_ref
+    bat_ref_aval: shaped_array_ref
     bat_ref_shape: Shape
     bat_non_slice_idx_avals: tuple[core.ShapedArray, ...]
     bat_non_slice_idx_shapes: tuple[Shape, ...]
@@ -752,7 +759,7 @@ if CAN_USE_HYPOTHESIS:
                                 min_value=0, max_value=len(index_param.slice_shape)))
     bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size)
-    bat_ref_aval = state.shaped_array_ref(bat_ref_shape, np.float32)
+    bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32)
     bat_non_slice_idx_avals = tuple(
         core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes)
     bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size)
@@ -829,7 +836,7 @@ if CAN_USE_HYPOTHESIS:
     def f(ref, *non_slice_idx):
       idx = _pack_idx(non_slice_idx, indexed_dims)
-      return [state.ref_get(ref, idx)]
+      return [ref_get(ref, idx)]
     ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval
     bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval
     bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals
@@ -843,13 +850,13 @@ if CAN_USE_HYPOTHESIS:
     f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
     stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
-    jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
+    jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
     discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx)

     # vmap-of-discharge
     stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f), [ref_aval, *idx_avals])
-    jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
+    jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
     f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                          in_axes=(ref_bdim, *idx_bdims),
                          out_axes=[out_bdim, ref_bdim])
@@ -867,7 +874,7 @@ if CAN_USE_HYPOTHESIS:
     def f(ref, val, *non_slice_idx):
       idx = _pack_idx(non_slice_idx, indexed_dims)
-      state.ref_set(ref, idx, val)
+      ref_set(ref, idx, val)
       return []
     ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval
     bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval
@@ -886,13 +893,13 @@ if CAN_USE_HYPOTHESIS:
                          out_axes=[])
     stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
-    jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
+    jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
     discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)

     # vmap-of-discharge
     stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
-    jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
+    jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
     f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                          in_axes=(ref_bdim, val_bdim, *idx_bdims),
                          out_axes=[ref_bdim])
@@ -910,7 +917,7 @@ if CAN_USE_HYPOTHESIS:
     def f(ref, val, *non_slice_idx):
       idx = _pack_idx(non_slice_idx, indexed_dims)
-      state.ref_addupdate(ref, idx, val)
+      ref_addupdate(ref, idx, val)
       return []
     ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval
     bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval
@@ -929,13 +936,13 @@ if CAN_USE_HYPOTHESIS:
                          out_axes=[])
     stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
-    jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
+    jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
     discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)

     # vmap-of-discharge
     stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
-    jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
+    jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
     f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                          in_axes=(ref_bdim, val_bdim, *idx_bdims),
                          out_axes=[ref_bdim])
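Each of the vmap property tests above asserts the same commutation square: discharging a batched stateful jaxpr (discharge-of-vmap) must agree with batching the evaluation of the discharged unbatched jaxpr (vmap-of-discharge). Schematically, with f, refs, ref_aval, and bat_ref_aval standing in for the Hypothesis-drawn values, and trace a hypothetical shorthand for the trace_to_jaxpr_dynamic calls used above:

  from functools import partial
  import jax

  def trace(fn, avals):
    # Hypothetical shorthand: trace `fn` at `avals`, keeping jaxpr and consts.
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fn), avals)
    return jaxpr, consts

  # discharge-of-vmap: batch the stateful function, then discharge its Refs.
  jaxpr, consts = discharge_state(*trace(jax.vmap(f, in_axes=(0,)),
                                         [bat_ref_aval]))
  lhs = core.eval_jaxpr(jaxpr, consts, refs)

  # vmap-of-discharge: discharge first, then batch the pure evaluation.
  jaxpr_, consts_ = discharge_state(*trace(f, [ref_aval]))
  rhs = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), in_axes=(0,))(refs)
  # The tests assert lhs == rhs elementwise.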
@@ -1031,10 +1038,10 @@ class GeneralRefTest(jtu.JaxTestCase):
     def f(x_ref):
       x = x_ref[...]
       x_ref[...] = x
-      state.ref_addupdate(x_ref, (), x)
+      ref_addupdate(x_ref, (), x)
       return [x]
     jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
-        lu.wrap_init(f), [state.AbstractRef(core.UnshapedArray(jnp.int32))])
+        lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))])
     self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
     self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))
@@ -1042,10 +1049,10 @@ class GeneralRefTest(jtu.JaxTestCase):
     def f(x_ref):
       x = x_ref[...]
       x_ref[...] = x
-      state.ref_addupdate(x_ref, (), x)
+      ref_addupdate(x_ref, (), x)
       return [x]
     jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
-        lu.wrap_init(f), [state.AbstractRef(core.AbstractToken())])
+        lu.wrap_init(f), [AbstractRef(core.AbstractToken())])
     self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)

   def test_ref_of_ref(self):
@@ -1055,9 +1062,356 @@ class GeneralRefTest(jtu.JaxTestCase):
     # Not sure why you'd ever want to do this, but it works!
     jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f),
-        [state.AbstractRef(state.AbstractRef(core.ShapedArray((), jnp.int32)))])
-    self.assertIs(type(jaxpr.outvars[0].aval), state.AbstractRef)
+        [AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))])
+    self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)
     self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray)
+
+class RunStateTest(jtu.JaxTestCase):
+
+  def test_simple_run_state(self):
+    out = run_state(lambda _: None)(1)
+    self.assertEqual(out, 1)
+
+  def test_nontrivial_run_state(self):
+    def f(refs):
+      x_ref, y_ref = refs
+      x = x_ref[...] * y_ref[...]
+      y_ref[...] = x * 2
+      x_ref[...] = y_ref[...] + x_ref[...]
+    # x + x * y * 2, x * y * 2
+    x, y = run_state(f)((2, 3))
+    self.assertEqual(x, 2 + 2 * 3 * 2)
+    self.assertEqual(y, 2 * 3 * 2)
+
+  def test_simple_run_state_with_multiple_refs(self):
+    out1, out2 = run_state(lambda _: None)((1, 2))
+    self.assertEqual(out1, 1)
+    self.assertEqual(out2, 2)
+
+  def test_simple_run_state_with_tuple(self):
+    out1, out2 = run_state(lambda _: None)((1, 2))
+    self.assertEqual(out1, 1)
+    self.assertEqual(out2, 2)
+
+  def test_can_stage_run_state(self):
+    def f(x):
+      return run_state(lambda _: None)(x)
+    _ = jax.make_jaxpr(f)(2)
+
+  def test_nested_run_state_captures_effects(self):
+    def f(x):
+      def body(x_ref):
+        def inner(y_ref):
+          y_ref[...] + x_ref[...]
+        run_state(inner)(1)
+      return run_state(body)(x)
+    jaxpr = jax.make_jaxpr(f)(2)
+    self.assertEmpty(jaxpr.effects)
+    self.assertEmpty(jaxpr.jaxpr.eqns[0].effects)
+    self.assertSetEqual(jaxpr.jaxpr.eqns[0].params["jaxpr"].effects,
+                        {ReadEffect(0)})
+    self.assertSetEqual(
+        jaxpr.jaxpr.eqns[0].params["jaxpr"].eqns[0].params["jaxpr"].effects,
+        {ReadEffect(0), ReadEffect(1)})
+
+  def test_jvp_of_run_state(self):
+    @run_state
+    def f(refs):
+      x_ref, y_ref = refs
+      y_ref[...] = jnp.sin(x_ref[...])
+    xy, xy_t = jax.jvp(f, ((2., 1.),), ((3., 1.),))
+    # x, sin(x)
+    self.assertAllClose(xy, (2., np.sin(2.)))
+    # t, cos(x) * t
+    self.assertAllClose(xy_t, (3., 3 * np.cos(2.)))
+
+    x, x_t = jax.jvp(lambda x: f((x, 0.))[1], (2.,), (3.,))
+    self.assertAllClose(x, np.sin(2.))
+    self.assertAllClose(x_t, 3 * np.cos(2.))
+
+  def test_jvp_of_run_state_with_zero_tangent(self):
+    @run_state
+    def f(refs):
+      x_ref, z_ref, y_ref = refs
+      del z_ref
+      y_ref[...] = jnp.sin(x_ref[...])
+    x, x_t = jax.jvp(lambda x: f((x, 0., 0.,))[2], (2.,), (3.,))
+    self.assertAllClose(x, np.sin(2.))
+    self.assertAllClose(x_t, 3 * np.cos(2.))
+
+  def test_linearize_of_run_state(self):
+    @run_state
+    def f(refs):
+      x_ref, y_ref = refs
+      y_ref[...] = jnp.sin(x_ref[...])
+
+    (x, y), f_lin = jax.linearize(f, (1., 0.))
+    self.assertAllClose(x, 1.)
+    self.assertAllClose(y, np.sin(1.))
+    x_t, y_t = f_lin((2., 1.))
+    self.assertAllClose(x_t, 2.)
+    self.assertAllClose(y_t, 2. * np.cos(1.))
+
+  def test_grad_of_run_state(self):
+    @run_state
+    def f(refs):
+      x_ref, y_ref = refs
+      y_ref[...] = jnp.sin(x_ref[...])
+
+    def sin(x):
+      return f((x, 0.))[1]
+
+    x_g = jax.grad(sin)(1.)
+    self.assertAllClose(x_g, np.cos(1.))
+
+    x_g2 = jax.grad(jax.grad(sin))(1.)
+    self.assertAllClose(x_g2, -np.sin(1.))
+
+    x_g3 = jax.grad(jax.grad(jax.grad(sin)))(1.)
+    self.assertAllClose(x_g3, -np.cos(1.))
+
+  def test_vjp_of_run_state(self):
+    @run_state
+    def f(refs):
+      x_ref, y_ref = refs
+      y_ref[...] = jnp.sin(x_ref[...])
+
+    (x, y), f_vjp = jax.vjp(f, (1., 0.))
+    self.assertAllClose(x, 1.)
+    self.assertAllClose(y, np.sin(1.))
+    ((x_ct, y_ct),) = f_vjp((0., 1.))
+    self.assertAllClose(x_ct, np.cos(1.))
+    self.assertAllClose(y_ct, 0.)
+
+  def test_vjp_of_run_state_single(self):
+    @run_state
+    def f(x_ref):
+      x = x_ref[...]
+      def _body(ref):
+        ref[...] = jnp.sin(ref[...])
+      x = run_state(_body)(x)
+      x_ref[...] = x
+
+    y, f_lin = jax.linearize(f, 1.)
+    self.assertAllClose(y, np.sin(1.))
+    y_t = f_lin(1.)
+    self.assertAllClose(y_t, np.cos(1.))
+
+    y, f_vjp = jax.vjp(f, 1.)
+    self.assertAllClose(y, np.sin(1.))
+    x_ct, = f_vjp(1.)
+    self.assertAllClose(x_ct, np.cos(1.))
+
+    jtu.check_grads(f, (0.5,), order=3)
+
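RunStateTest pins down the user-facing contract of the new primitive: run_state(f)(init) hands f one mutable Ref per leaf of init, runs it for its write effects, and returns the final Ref values with the input's pytree structure, composing with jvp, linearize, grad, and vjp. A minimal usage sketch; the import path for run_state is an assumption based on this patch's changes to jax/_src/state/discharge.py:

  import jax
  import jax.numpy as jnp
  from jax._src.state.discharge import run_state  # assumed location

  def body(refs):
    x_ref, y_ref = refs                # one Ref per leaf of the input
    y_ref[...] = jnp.sin(x_ref[...])   # writes mutate the Ref in place

  x, y = run_state(body)((2., 0.))     # final values: (2.0, sin(2.0))
  # Differentiation goes through the stateful body:
  g = jax.grad(lambda x: run_state(body)((x, 0.))[1])(2.)  # cos(2.0)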
+if CAN_USE_HYPOTHESIS:
+
+  class FuncSpec(NamedTuple):
+    fun: Callable[..., Any]
+    name: str
+    min_rank: int = 0
+    max_rank: int = 4
+    min_dim: int = 0
+    max_dim: int = 4
+
+    def call(self, *args):
+      return run_state(self.fun)(*args)
+
+    def ref(self, *args):
+      return run_state_reference(self.fun)(*args)
+
+  def sin_stateful(refs):
+    x_ref, y_ref = refs
+    y_ref[...] = jnp.sin(x_ref[...])
+
+  sin_spec = FuncSpec(sin_stateful, "sin")
+
+  def cos_stateful(refs):
+    x_ref, y_ref = refs
+    y_ref[...] = jnp.cos(x_ref[...])
+
+  cos_spec = FuncSpec(cos_stateful, "cos")
+
+  def mul2_stateful(refs):
+    x_ref, y_ref = refs
+    y_ref[...] = x_ref[...]
+    y_ref[...] = y_ref[...] + x_ref[...]
+
+  mul2_spec = FuncSpec(mul2_stateful, "mul2")
+
+  def mul2_stateful_with_constant(refs):
+    x_ref, y_ref = refs
+    y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...]
+
+  mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c")
+
+  def crazy_identity_stateful(refs):
+    x_ref, y_ref = refs
+    x = x_ref[...]
+    x_ref[...] = (x + x) / 2
+    y_ref[...] = x_ref[...]
+    y = y_ref[...]
+    y_ref[...] = (y + y) / 2
+
+  crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id")
+
+  def func_spec(depth: int = 4):
+    raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec,
+                                  mul2_constant_spec, crazy_identity_spec])
+    if depth > 0:
+      return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1),
+                         compose_spec(depth - 1)])
+    return raw_specs
+
+  @hps.composite
+  def compose_spec(draw, depth):
+    f1 = draw(func_spec(depth))
+    f2 = draw(func_spec(depth))
+    def wrapped_impl(*args):
+      f1.fun(*args)
+      f2.fun(*args)
+    return FuncSpec(wrapped_impl,
+                    f"({f2.name} . {f1.name})",
+                    min_rank=max(f1.min_rank, f2.min_rank),
+                    max_rank=min(f1.max_rank, f2.max_rank),
+                    min_dim=max(f1.min_dim, f2.min_dim),
+                    max_dim=min(f1.max_dim, f2.max_dim))
+
+  @hps.composite
+  def nest_spec(draw, depth):
+    f = draw(func_spec(depth))
+    def wrapped_impl(refs):
+      x_ref, y_ref = refs
+      x, y = x_ref[...], y_ref[...]
+      x, y = run_state(f.fun)((x, y))
+      x_ref[...], y_ref[...] = x, y
+    return FuncSpec(wrapped_impl,
+                    f"nest({f.name})",
+                    min_rank=f.min_rank,
+                    max_rank=f.max_rank,
+                    min_dim=f.min_dim,
+                    max_dim=f.max_dim)
+
+
+  @hps.composite
+  def add_spec(draw, depth):
+    f1 = draw(func_spec(depth))
+    f2 = draw(func_spec(depth))
+    def wrapped_impl(refs):
+      x_ref, y_ref = refs
+      x, y = x_ref[...], y_ref[...]
+      x1, y1 = run_state(f1.fun)((x, y))
+      x2, y2 = run_state(f2.fun)((x, y))
+      x_ref[...], y_ref[...] = x1 + x2, y1 + y2
+    return FuncSpec(wrapped_impl,
+                    f"({f2.name} + {f1.name})",
+                    min_rank=max(f1.min_rank, f2.min_rank),
+                    max_rank=min(f1.max_rank, f2.max_rank),
+                    min_dim=max(f1.min_dim, f2.min_dim),
+                    max_dim=min(f1.max_dim, f2.max_dim))
+
+  class RunStateHypothesisTest(jtu.JaxTestCase):
+
+    @hp.given(hps.data())
+    @hp.settings(deadline=None, print_blob=True,
+                 max_examples=config.FLAGS.jax_num_generated_cases)
+    def test_jvp(self, data):
+
+      spec = data.draw(func_spec())
+
+      def impl(x):
+        return spec.call((x, jnp.zeros_like(x)))[1]
+
+      def ref(x):
+        return spec.ref((x, jnp.zeros_like(x)))[1]
+
+      k1, k2 = random.split(random.PRNGKey(0))
+      shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank,
+                                         max_dims=spec.max_rank,
+                                         min_side=spec.min_dim,
+                                         max_side=spec.max_dim))
+      x = random.normal(k1, shape)
+      t = random.normal(k2, x.shape)
+      y, y_t = jax.jvp(impl, (x,), (t,))
+      y_ref, y_ref_t = jax.jvp(ref, (x,), (t,))
+      self.assertAllClose(y, y_ref)
+      self.assertAllClose(y_t, y_ref_t)
+
+    @hp.given(hps.data())
+    @hp.settings(deadline=None, print_blob=True,
+                 max_examples=config.FLAGS.jax_num_generated_cases)
+    def test_linearize(self, data):
+
+      spec = data.draw(func_spec())
+
+      def impl(x):
+        return spec.call((x, jnp.zeros_like(x)))[1]
+
+      def ref(x):
+        return spec.ref((x, jnp.zeros_like(x)))[1]
+
+
+      k1, k2 = random.split(random.PRNGKey(0))
+      shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank,
+                                         max_dims=spec.max_rank,
+                                         min_side=spec.min_dim,
+                                         max_side=spec.max_dim))
+      x = random.normal(k1, shape)
+      y, impl_lin = jax.linearize(impl, x)
+      y_ref, ref_lin = jax.linearize(ref, x)
+      self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2)
+      t = random.normal(k2, x.shape)
+      self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2)
+
+    @hp.given(hps.data())
+    @hp.settings(deadline=None, print_blob=True,
+                 max_examples=config.FLAGS.jax_num_generated_cases)
+    def test_vjp(self, data):
+
+      spec = data.draw(func_spec())
+
+      def impl(x):
+        return spec.call((x, jnp.zeros_like(x)))[1]
+
+      def ref(x):
+        return spec.ref((x, jnp.zeros_like(x)))[1]
+
+
+      key, k1, k2 = random.split(random.PRNGKey(0), 3)
+      shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank,
+                                         max_dims=spec.max_rank,
+                                         min_side=spec.min_dim,
+                                         max_side=spec.max_dim))
+      x = random.normal(k1, shape)
+
+      # First order
+      y, impl_lin = jax.linearize(impl, x)
+      y_ref, ref_lin = jax.linearize(ref, x)
+      self.assertAllClose(y, y_ref)
+      t = random.normal(k2, x.shape)
+      self.assertAllClose(impl_lin(t), ref_lin(t))
+
+      y, impl_vjp = jax.vjp(impl, x)
+      y_ref, ref_vjp = jax.vjp(ref, x)
+      self.assertAllClose(y, y_ref)
+      t = random.normal(k2, x.shape)
+      y2 = random.normal(k1, y.shape)
+      self.assertAllClose(impl_vjp(t), ref_vjp(t))
+
+      # Second order
+      key, k1, k2 = random.split(key, 3)
+      t2 = random.normal(k2, t.shape)
+
+      (x,), impl_lin2 = jax.linearize(impl_vjp, t2)
+      (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2)
+      self.assertAllClose(x, x_ref)
+      y2 = random.normal(k1, y.shape)
+      self.assertAllClose(impl_lin2(y2), ref_lin2(y2))
+
+      (x,), impl_vjp2 = jax.vjp(impl_vjp, t2)
+      (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2)
+      self.assertAllClose(x, x_ref)
+      y2 = random.normal(k1, y.shape)
+      self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,)))
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
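The Hypothesis suite above reduces to one property: run_state and the discharge-based run_state_reference must agree as functions and under first- and second-order differentiation, over randomly composed, nested, and summed stateful bodies. Specialized to sin_stateful, the check looks like this sketch (import path again an assumption based on the patch):

  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax._src.state.discharge import run_state, run_state_reference  # assumed

  def sin_stateful(refs):
    x_ref, y_ref = refs
    y_ref[...] = jnp.sin(x_ref[...])

  # Mirrors FuncSpec.call / FuncSpec.ref: feed x in, read the second Ref out.
  impl = lambda x: run_state(sin_stateful)((x, jnp.zeros_like(x)))[1]
  ref = lambda x: run_state_reference(sin_stateful)((x, jnp.zeros_like(x)))[1]

  x = jnp.float32(0.7)
  np.testing.assert_allclose(impl(x), ref(x), rtol=1e-6)
  np.testing.assert_allclose(jax.grad(impl)(x), jax.grad(ref)(x), rtol=1e-6)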