Merge pull request #15149 from sharadmv:runstate

PiperOrigin-RevId: 521809360
This commit is contained in:
jax authors 2023-04-04 10:56:25 -07:00
commit 3c1f3abba2
11 changed files with 1140 additions and 194 deletions

View File

@ -364,6 +364,7 @@ pytype_strict_library(
":effects",
":profiler",
":source_info_util",
":state_types",
":tree_util",
":util",
] + py_deps("numpy"),
@ -411,6 +412,20 @@ pytype_strict_library(
],
)
pytype_strict_library(
name = "state_types",
srcs = [
"_src/state/__init__.py",
"_src/state/types.py",
],
deps = [
":core",
":effects",
":pretty_printer",
":util",
],
)
pytype_strict_library(
name = "tree_util",
srcs = ["_src/tree_util.py"],

View File

@ -42,6 +42,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent)
from jax._src.state.types import AbstractRef
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
KeyPath, generate_key_paths, keystr)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@ -1123,9 +1124,36 @@ def partial_eval_jaxpr_custom(
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
return _partial_eval_jaxpr_custom_cached(
jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
if num_res_ref > 0:
raise ValueError(
"Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res
def partial_eval_jaxpr_stateful(
jaxpr: Jaxpr,
in_unknowns: Sequence[bool],
in_inst: Union[bool, Sequence[bool]],
ensure_out_unknowns: Union[bool, Sequence[bool]],
ensure_out_inst: Union[bool, Sequence[bool]],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref
@weakref_lru_cache
def _partial_eval_jaxpr_custom_cached(
@ -1135,9 +1163,10 @@ def _partial_eval_jaxpr_custom_cached(
ensure_out_unknowns: Tuple[bool, ...],
ensure_out_inst: Tuple[bool, ...],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
env: Dict[Var, Tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
residual_refs: OrderedSet[Var] = OrderedSet()
def read(x: Atom) -> Tuple[bool, bool]:
if type(x) is Var:
@ -1162,7 +1191,11 @@ def _partial_eval_jaxpr_custom_cached(
if rule:
eqn1, eqn2, unks_out, inst_out, res = rule(saveable, unks_in, inst_in, eqn)
eqn1 and known_eqns.append(eqn1); eqn2 and staged_eqns.append(eqn2) # type: ignore
residuals.update(res)
for r in res:
if isinstance(r.aval, AbstractRef):
residual_refs.add(r)
else:
residuals.add(r)
map(write, unks_out, inst_out, eqn.outvars)
elif any(unks_in):
inputs = map(ensure_instantiated, inst_in, eqn.invars)
@ -1189,24 +1222,27 @@ def _partial_eval_jaxpr_custom_cached(
ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
ref_res_is_input = [r in ins_known for r in residual_refs]
non_input_res_refs, _ = partition_list(ref_res_is_input, list(residual_refs))
ins_known_and_ref_res = [*ins_known, *non_input_res_refs]
known_outvars = [*outs_known, *residuals]
known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known, known_outvars,
known_eqns)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known, known_outvars,
known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res,
known_outvars, known_eqns)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
known_eqns, known_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
_, ins_staged = partition_list(in_inst, jaxpr.invars)
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns))
staged_invars = [*residuals, *ins_staged]
staged_invars = [*residuals, *non_input_res_refs, *ins_staged]
staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns)
outs_staged, staged_eqns)
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns, staged_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
len(non_input_res_refs))
# A primitive rule for policy-driven partial evaluation returns a 5-tuple
# with the components representing, respectively:
@ -1283,9 +1319,10 @@ def closed_call_partial_eval_custom_rule(
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
closed_jaxpr = eqn.params[jaxpr_param_name]
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
partial_eval_jaxpr_custom(closed_jaxpr.jaxpr, unks_in, inst_in,
False, False, saveable)
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \
partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in,
False, False, saveable)
num_res = num_res_ref + num_res_out
# Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
@ -1299,18 +1336,25 @@ def closed_call_partial_eval_custom_rule(
params_known, params_staged = params_updater(
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
params_staged)
residuals = [newvar(res_aval(params_known, a))
for a in jaxpr_staged.in_avals[:num_res]]
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
residuals, ref_residuals = split_list(
[newvar(res_aval(params_known, v)) for v
in jaxpr_staged.in_avals[:num_res]], [num_res_out])
eqn_known = new_jaxpr_eqn([*ins_known, *ref_residuals],
[*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info)
eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
eqn_staged = new_jaxpr_eqn([*residuals, *ref_residuals, *ins_staged],
out_binders_staged,
eqn.primitive, params_staged, jaxpr_staged.effects,
eqn.source_info)
assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
assert len(ins_known) + len(ref_residuals) == len(jaxpr_known.jaxpr.invars)
assert len(ins_staged) + len(ref_residuals) + len(residuals) == len(jaxpr_staged.jaxpr.invars)
assert len(out_binders_known) + len(residuals) == len(jaxpr_known.jaxpr.outvars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is Var and not inst]
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
new_vars = [*new_inst, *residuals, *ref_residuals]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars
partial_eval_jaxpr_custom_rules[core.call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
@ -1536,16 +1580,25 @@ class DynamicJaxprTracer(core.Tracer):
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
del outvars
jaxpr_effects = set()
all_vars = [*constvars, *invars]
for eqn in eqns:
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
if eff.input_index >= len(eqn.invars):
raise ValueError(
f"`JaxprInputEffect` {eff} is invalid."
f"\n Equation: {eqn}\n"
"\n Jaxpr: "
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
invar = eqn.invars[eff.input_index]
if invar not in all_vars:
raise ValueError(
"`JaxprInputEffect` does not have corresponding input.")
f"`JaxprInputEffect` {eff} does not have "
f"corresponding input: {invar}."
f"\n Equation: {eqn}\n"
"\n Jaxpr: "
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
eff = eff.replace(input_index=all_vars.index(invar))
jaxpr_effects.add(eff)
return jaxpr_effects

View File

@ -31,7 +31,8 @@ from jax._src import effects
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import util
from jax._src import state
from jax._src.state.discharge import register_discharge_rule, discharge_state
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -237,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
if any(isinstance(op_aval, state.AbstractRef) for op_aval in ops_avals):
if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals):
raise ValueError("Cannot pass `Ref`s into `cond`.")
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees
@ -338,7 +339,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
index, *ops = args
index_dim, *op_dims = dims
# TODO(sharadmv): clean this up by adding a specific blocklist
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
if any(isinstance(eff, RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
raise NotImplementedError(
"State effect not supported in vmap-of-cond.")
@ -426,7 +427,7 @@ def _cond_jvp(primals, tangents, branches, linear):
def _cond_partial_eval(trace, *tracers, branches, linear):
in_unknowns = [t.pval[0] is not None for t in tracers]
index_uk, *ops_uk = in_unknowns
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
if any(isinstance(eff, RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
raise NotImplementedError(
"State effect not supported in cond partial-eval.")
@ -703,7 +704,7 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear):
linear = [type(x) is ad.UndefinedPrimal for x in ops]
in_avals = map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
if any(isinstance(eff, RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
raise NotImplementedError("State effect not supported in cond transpose.")
@ -856,10 +857,10 @@ def _cond_lowering(ctx, index, *args, branches, linear):
mlir.register_lowering(cond_p, _cond_lowering)
@state.register_discharge_rule(cond_p)
@register_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
discharged_branches = tuple(
core.ClosedJaxpr(state.discharge_state(branch.jaxpr, ())[0], ())
core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ())
for branch in branches)
out_vals = cond_p.bind(*args, branches=discharged_branches, linear=linear)
out_ref_vals, out_vals = util.split_list(
@ -868,5 +869,5 @@ def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
new_invals = []
for aval in in_avals:
new_invals.append(
next(ref_val_iter) if isinstance(aval, state.AbstractRef) else None)
next(ref_val_iter) if isinstance(aval, AbstractRef) else None)
return new_invals, out_vals

View File

@ -33,7 +33,11 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import state
from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect)
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.state import utils as state_utils
from jax._src.state import types as state_types
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
split_list, split_dict)
from jax._src.lax.control_flow import loops
@ -50,15 +54,10 @@ T = TypeVar('T')
class Ref(Generic[T]): pass
Array = Any
ReadEffect = state.ReadEffect
WriteEffect = state.WriteEffect
AccumEffect = state.AccumEffect
StateEffect = state.StateEffect
AbstractRef = state.AbstractRef
ref_set = state.ref_set
ref_get = state.ref_get
ref_addupdate = state.ref_addupdate
discharge_state = state.discharge_state
ref_set = state_primitives.ref_set
ref_get = state_primitives.ref_get
ref_addupdate = state_primitives.ref_addupdate
discharge_state = state_discharge.discharge_state
## `for_loop` implementation
@ -100,12 +99,6 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
f, state_avals)
return jaxpr, consts, out_tree_thunk()
def val_to_ref_aval(x) -> AbstractRef:
aval = core.raise_to_shaped(core.get_aval(x))
if type(aval) is not core.ShapedArray:
raise Exception(f"can't make ref from {x}")
return AbstractRef(aval)
def for_loop(nsteps: Union[int, Sequence[int]],
body: Callable[[Array, Ref[S]], None], init_state: S,
*, reverse: bool = False, unroll: int = 1) -> S:
@ -151,14 +144,14 @@ def for_loop(nsteps: Union[int, Sequence[int]],
if len(nsteps) > 1:
outer_step, *rest_steps = nsteps
def wrapped_body(i, refs):
vals = tree_map(lambda ref: state.ref_get(ref, ()), refs)
vals = tree_map(lambda ref: ref_get(ref, ()), refs)
vals = for_loop(
rest_steps, functools.partial(body, i), vals, unroll=unroll)
tree_map(lambda ref, val: state.ref_set(ref, (), val), refs, vals)
tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals)
return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
nsteps, = nsteps
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(val_to_ref_aval, flat_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), jnp.dtype("int32"))
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])
@ -247,7 +240,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
@for_p.def_effectful_abstract_eval
def _for_abstract_eval(*avals, jaxpr, **__):
# Find out for each of the `Ref`s in our jaxpr what effects they have.
jaxpr_aval_effects = state.get_ref_state_effects(
jaxpr_aval_effects = state_types.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
aval_effects = [set(eff.replace(input_index=eff.input_index - 1)
for eff in effs) for aval, effs
@ -256,7 +249,7 @@ def _for_abstract_eval(*avals, jaxpr, **__):
nonlocal_state_effects = core.join_effects(*aval_effects)
return list(avals), nonlocal_state_effects
@state.register_discharge_rule(for_p)
@state_discharge.register_discharge_rule(for_p)
def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
reverse: bool, which_linear: Sequence[bool],
nsteps: int, unroll: int
@ -388,7 +381,7 @@ def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
# Get effects for each of the jaxpr inputs and remove the loop index.
ref_effects = state.get_ref_state_effects(
ref_effects = state_types.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
# We first assume that *read-only `Ref`s* are loop-invariant. We can safely do
# this because the only way something can be loop-varying is if we write to it
@ -771,7 +764,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
Potentially useful for testing and benchmarking.
"""
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(val_to_ref_aval, flat_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), jnp.dtype("int32"))
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])

View File

@ -15,7 +15,3 @@
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect, StateEffect, RefEffect,
get_ref_state_effects, shaped_array_ref)
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
ref_addupdate, get_p, swap_p,
addupdate_p)
from jax._src.state.discharge import discharge_state, register_discharge_rule

View File

@ -15,24 +15,34 @@
from __future__ import annotations
import dataclasses
from functools import partial
import operator
from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union
import numpy as np
from jax import lax
from jax._src import api_util
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.state.types import AbstractRef
from jax._src import source_info_util
from jax._src import tree_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.util import safe_map, safe_zip, split_list
from jax._src.state.utils import hoist_consts_to_refs
from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
partition_list, merge_lists, split_dict)
## JAX utilities
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
PyTreeDef = tree_util.PyTreeDef
## Discharging state
@ -241,3 +251,462 @@ def _closed_call_discharge_rule(
else None for aval in in_avals)
assert next(ref_vals_iter, None) is None
return new_invals, out_vals
# # `run_state`
run_state_p = core.Primitive("run_state")
run_state_p.multiple_results = True
def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...]):
if config.jax_enable_checks:
core.check_jaxpr(jaxpr)
assert len(jaxpr.invars) == len(args)
assert len(which_linear) == len(args)
return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr,
which_linear=which_linear)
run_state_p.def_custom_bind(_run_state_bind)
def _run_state_impl(*args: Any, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...]):
del which_linear
discharged_jaxpr, consts = discharge_state(jaxpr, ())
return core.eval_jaxpr(discharged_jaxpr, consts, *args)
run_state_p.def_impl(_run_state_impl)
def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...]):
del which_linear
# When we abstractly evaluate `run_state`, we want to keep track of which
# input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to
# "propagate" out its inner effects. Otherwise, the effects are local to this
# `run_state`.
is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)}
nonlocal_effects = {e for e in jaxpr.effects
if (isinstance(e, RefEffect) and e.input_index in is_ref)
or not isinstance(e, RefEffect)}
return avals, nonlocal_effects
run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)
def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *,
jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
for _ in range(len(nonzero_tangents)):
_, out_nonzero_tangents = ad.jvp_jaxpr(
core.ClosedJaxpr(discharged_jaxpr, body_consts),
nonzero_tangents, instantiate=nonzero_tangents)
if out_nonzero_tangents == nonzero_tangents:
break
nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents)
else:
raise Exception("Invalid fixpoint")
del discharged_jaxpr, body_consts, out_nonzero_tangents
tangents = [ad.instantiate_zeros(t) if inst else t
for t, inst in zip(tangents, nonzero_tangents)]
tangents = [t for t in tangents if type(t) is not ad_util.Zero]
closed_jvp_jaxpr, _ = ad.jvp_jaxpr(core.ClosedJaxpr(jaxpr, ()),
nonzero_tangents, [])
jvp_jaxpr_, jvp_consts = closed_jvp_jaxpr.jaxpr, closed_jvp_jaxpr.consts
jvp_jaxpr = hoist_consts_to_refs(jvp_jaxpr_)
jvp_which_linear = (*(False,) * len(jvp_consts), *which_linear, *(True,) * len(tangents))
out = run_state_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr,
which_linear=jvp_which_linear)
out_consts, out_primals, out_tangents = split_list(out, [len(jvp_consts),
len(primals)])
del out_consts
out_tangents_iter = iter(out_tangents)
out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
for p, nz in zip(out_primals, nonzero_tangents)]
return out_primals, out_tangents
ad.primitive_jvps[run_state_p] = _run_state_jvp
_save_everything = lambda *_, **__: True
def _convert_outputs_to_writes(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
in_avals = [v.aval for v in jaxpr.invars]
@lu.wrap_init
def eval_jaxpr(*refs):
# We split the refs into the original input refs and the dummy residual
# refs.
orig_refs, residual_refs = split_list(refs, [len(in_avals)])
residual_vals = core.eval_jaxpr(jaxpr, (), *orig_refs)
for res_ref, res_val in zip(residual_refs, residual_vals):
res_ref[...] = res_val
return []
res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
else v.aval for v in jaxpr.outvars]
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
assert not consts
return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error
for a in res_ref_avals]
def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
assert not jaxpr.constvars, "Jaxpr should not have constvars"
@lu.wrap_init
def eval_jaxpr(*refs):
residual_refs, orig_refs = split_list(refs, [num_res])
residual_vals = [r[...] for r in residual_refs]
() = core.eval_jaxpr(jaxpr, (), *residual_vals, *orig_refs)
return []
res_val_avals, orig_ref_avals = \
split_list([v.aval for v in jaxpr.invars], [num_res])
res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
aval for aval in res_val_avals]
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
return jaxpr
def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
num_inputs = len(tracers)
assert num_inputs == len(jaxpr.invars)
in_unknowns = [not t.pval.is_known() for t in tracers]
# We first need to run a fixpoint to determine which of the `Ref`s are unknown
# after running the for loop. We want to use the jaxpr to determine which
# `Ref`s are unknown after executing the for loop body given which `Ref`s are
# unknown before. However, the jaxpr has no outputs. Instead, we discharge
# the body and run the fixpoint with the discharged jaxpr. We can do this
# because the outputs of the jaxpr are one-to-one with the inputs.
discharged_jaxpr_, discharged_consts = discharge_state(jaxpr, ())
discharged_jaxpr = pe.convert_constvars_jaxpr(discharged_jaxpr_)
for _ in range(num_inputs):
jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
_, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
discharged_jaxpr, jaxpr_in_unknowns, jaxpr_in_unknowns,
in_unknowns, False, _save_everything)
# assert out_inst == out_unknowns
out_unknowns = list(out_unknowns)
if out_unknowns == in_unknowns:
break
in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
else:
raise Exception("Invalid fixpoint")
del out_unknowns # redundant since it's the same as `in_unknowns`
tracers = tuple(trace.instantiate_const(t) if uk else t # type: ignore
for t, uk in zip(tracers, in_unknowns))
# We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
# primitives like `get`/`set`.
jaxpr_known_resout, jaxpr_unknown_resin_, _, _, num_res_out, num_res_ref = \
pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns, in_inst=in_unknowns,
ensure_out_unknowns=[], ensure_out_inst=[],
saveable=_save_everything)
# # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref`
# and regular valued input/outputs. However, we'd like to bind these jaxprs to
# a `for`, which expects only `Ref` inputs and no output. We need to convert
# both of these jaxprs into ones that are compatible with `for`.
# `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
# to output residual values (none of them should be `Ref`s). We'll need to
# convert the output residual values into `Ref`s that are initially empty
# `Ref`s that are written to at the end of the jaxpr.
num_res = num_res_out + num_res_ref
num_invars = len(jaxpr_known_resout.invars) - num_res_ref
_, res_ref_avals = split_list(
[v.aval for v in jaxpr_known_resout.invars], [num_invars])
res_avals = [a.inner_aval for a in res_ref_avals] # pytype: disable=attribute-error
jaxpr_known, new_res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
# We now run the known jaxpr to obtain our residual values.
known_tracers, _ = partition_list(in_unknowns, tracers)
known_which_linear, _ = partition_list(in_unknowns, which_linear)
known_vals = [t.pval.get_known() for t in known_tracers]
all_res_avals = [*res_avals, *new_res_avals]
empty_res = map(ad_util.zeros_like_aval, all_res_avals)
jaxpr_known_args = [*known_vals, *empty_res]
jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
out_flat = run_state_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known,
which_linear=jaxpr_known_which_linear)
known_outputs, residuals = split_list(out_flat, [len(known_tracers)])
residuals = map(trace.new_instantiated_const, residuals)
ref_res, nonref_res = split_list(residuals, [num_res_ref])
# Now we handle the `jaxpr_unknown` that expects residual values as inputs.
# This jaxpr is the output of `partial_eval_jaxpr_stateful` that marks which
# inputs are actually used.
# `partial_eval_jaxpr_stateful` doesn't remove extra inputs/outputs for you
# so we use `dce_jaxpr` here to do that.
# To make it compatible with `for`, we need to convert those residual values
# into `Ref`s.
jaxpr_unknown = _convert_inputs_to_reads(len(new_res_avals),
jaxpr_unknown_resin_)
_, unknown_tracers = partition_list(in_unknowns, tracers)
_, uk_which_linear = partition_list(in_unknowns, which_linear)
unknown_which_linear = (False,) * num_res + tuple(uk_which_linear)
unknown_inputs = [*nonref_res, *ref_res, *unknown_tracers]
# Outputs match inputs so we construct output tracers that look like the input
# tracers.
res_ref_unknown_outputs = [
pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None)
for t in unknown_inputs]
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
assert len(unknown_inputs) == len(res_ref_unknown_outputs)
assert len(unknown_inputs) == len(jaxpr_unknown.invars)
uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear)
_, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
**uk_params)
eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
run_state_p, uk_params,
eqn_effects, source)
for t in res_ref_unknown_outputs: t.recipe = eqn
_, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res])
return merge_lists(in_unknowns, known_outputs, unknown_outputs)
pe.custom_partial_eval_rules[run_state_p] = _run_state_partial_eval
def _run_state_partial_eval_custom(
saveable: Callable[..., bool],
in_unknowns: Sequence[bool],
in_inst: Sequence[bool],
eqn: core.JaxprEqn):
if not any(in_unknowns):
return eqn, None, in_unknowns, [False] * len(in_unknowns), []
jaxpr, which_linear = split_dict(eqn.params, ["jaxpr", "which_linear"])
num_inputs = len(eqn.invars)
# We first need to run a fixpoint to determine which of the `Ref`s are unknown
# after running the for loop. However, the jaxpr has no outputs. Instead, we
# discharge the body and run the fixpoint with the discharged jaxpr. We can do
# this because the outputs of the discharged jaxpr are one-to-one with the
# inputs.
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
discharged_jaxpr = discharged_jaxpr.replace(
invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
constvars=[])
in_unknowns, in_inst = list(in_unknowns), list(in_inst)
out_unknowns, out_inst = in_unknowns, in_unknowns
for _ in range(num_inputs):
jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
_, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
discharged_jaxpr,
in_unknowns=jaxpr_in_unknowns,
in_inst=jaxpr_in_unknowns,
ensure_out_unknowns=in_unknowns,
ensure_out_inst=in_unknowns,
saveable=saveable)
out_unknowns = list(out_unknowns)
if out_unknowns == in_unknowns:
break
in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
else:
if num_inputs > 0: raise Exception("Invalid fixpoint")
del out_unknowns # Redundant since it's the same as `in_unknowns`
new_inst = [x for x, already, inst in zip(eqn.invars, in_inst, out_inst)
if type(x) is core.Var and inst and not already]
# We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
# primitives like `get`/`set`.
jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res_out, num_res_ref = \
pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns,
in_unknowns, [], [], saveable)
num_res = num_res_ref + num_res_out
# `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref` and
# non-Ref input/outputs. However, we'd like to bind these jaxprs to a
# `for`, which expects only `Ref` inputs and no output. We need to convert
# both of these jaxprs into ones that are compatible with `for`.
# TODO(sharadmv,mattjj): implement "passthrough" optimization.
# `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
# to output residual values (none of them should be `Ref`s). We'll need to
# convert the output residual values into `Ref`s that are initially empty
# `Ref`s that are written to at the end of the jaxpr.
jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
# In a stateful partial_eval, the residuals should be `Ref`s.
res_avals = map(AbstractRef, res_avals) # type: ignore
known_invars, staged_invars = partition_list(in_unknowns, eqn.invars)
known_outvars, staged_outvars = partition_list(in_unknowns, eqn.outvars)
newvar = core.gensym()
_, res_ref_avals = split_list([v.aval for v in jaxpr_known_resout.invars],
[len(known_invars)])
nonref_resvars = map(newvar, res_avals)
ref_resvars = map(newvar, res_ref_avals)
known_out_resvars = map(newvar, [*res_ref_avals, *res_avals])
known_which_linear, _ = partition_list(in_unknowns, which_linear)
jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
known_and_res_invars = [*known_invars, *ref_resvars, *nonref_resvars]
known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear)
_, known_effects = run_state_p.abstract_eval(
*[v.aval for v in known_and_res_invars], **known_params)
eqn_known = pe.new_jaxpr_eqn(known_and_res_invars,
[*known_outvars, *known_out_resvars],
run_state_p, known_params,
known_effects, eqn.source_info)
jaxpr_staged = _convert_inputs_to_reads(len(res_avals), jaxpr_staged_resin_)
_, staged_which_linear = partition_list(in_unknowns, which_linear)
which_linear_unknown = (*[False] * num_res, *staged_which_linear)
staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown)
rejiggered_resvars = [*nonref_resvars, *ref_resvars]
_, staged_invars = partition_list(in_unknowns, eqn.invars)
res_staged_invars = [*rejiggered_resvars, *staged_invars]
_, staged_effects = run_state_p.abstract_eval(
*[v.aval for v in res_staged_invars], **staged_params)
_, staged_outvars = partition_list(in_unknowns, eqn.outvars)
if num_res:
@lu.wrap_init
def staged(*args):
out = run_state_p.bind(*args, **staged_params)
return out[num_res:]
staged_call_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(staged,
[v.aval for v in res_staged_invars])
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
staged_outvars,
core.closed_call_p,
dict(call_jaxpr=core.ClosedJaxpr(staged_call_jaxpr, ())),
staged_effects, eqn.source_info)
assert len(res_staged_invars) == len(staged_call_jaxpr.invars)
assert len(staged_outvars) == len(staged_call_jaxpr.outvars)
else:
eqn_staged = pe.new_jaxpr_eqn(staged_invars,
staged_outvars,
run_state_p,
staged_params,
staged_effects, eqn.source_info)
new_vars = [*new_inst, *nonref_resvars, *ref_resvars]
return eqn_known, eqn_staged, in_unknowns, in_unknowns, new_vars
pe.partial_eval_jaxpr_custom_rules[run_state_p] = _run_state_partial_eval_custom
def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
) -> tuple[core.Jaxpr, Any]:
def trans(*args):
# First we want to run the computation to read all the residual refs. We can
# do that by using partial evaluation with all linear inputs unknown.
res_jaxpr_, tangent_jaxpr_, *_, num_res_out, num_res_ref = \
pe.partial_eval_jaxpr_stateful(jaxpr, which_linear, in_inst=which_linear,
ensure_out_inst=[],
ensure_out_unknowns=[],
saveable=_save_everything)
num_unknown = sum(which_linear)
num_known = len(jaxpr.invars) - num_unknown
res_args, _ = partition_list(which_linear, args)
res_jaxpr_avals = [v.aval for v in res_jaxpr_.invars]
_, res_avals = split_list(res_jaxpr_avals, [num_known])
res_avals = [a.inner_aval for a in res_avals] # pytype: disable=attribute-error
all_avals = [*res_avals, *[v.aval for v in res_jaxpr_.outvars]]
empty_res = map(ad.zeros_like_aval, all_avals)
res_jaxpr, _ = _convert_outputs_to_writes(res_jaxpr_)
res = run_state_p.bind(*res_args, *empty_res, jaxpr=res_jaxpr,
which_linear=(False,) * (len(res_args) + len(empty_res)))
res = res[len(res_args):]
ref_res_, nonref_res_ = split_list(res, [num_res_ref])
# Now that we have residual values, we run the tangent jaxpr. It takes as
# input the residuals, the loop index, and all the refs (at least, the ones
# that are used in the body). Luckily, `tangent_jaxpr_` has all known and
# unknown inputs!
tangent_jaxpr, used_inputs = pe.dce_jaxpr(tangent_jaxpr_, [])
used_res, used_cts = split_list(used_inputs, [len(res)])
used_nonref_res, used_ref_res = split_list(used_res, [num_res_out])
_, nonref_res = partition_list(used_nonref_res, nonref_res_)
_, ref_res = partition_list(used_ref_res, ref_res_)
primals_args = [*nonref_res, *ref_res]
_, tangent_args = partition_list(which_linear, args)
_, ct_args = partition_list(used_cts, tangent_args)
ad.backward_pass(
tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
return []
jaxpr_trans, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
return jaxpr_trans, consts
def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...]):
# if any in_ct is nonzero, we definitely want it in args_ (and the
# corresponding x in args could be an undefined primal, but doesnt have to be)
# for non-res stuff:
# getting and setting => (nonzero ct, UndefinedPrimal arg)
# just setting => (nonzero ct, not UndefinedPrimal, dummy value)
# just getting => (zero ct , UndefinedPrimal arg)
# for res stuff:
# (zero ct , not UndefinedPrimal)
assert any(which_linear)
transpose_args = []
for x, ct in zip(args, in_cts):
if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x):
# this is a residual, take x!
transpose_args.append(x)
elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x):
# the loop was 'just getting', plug in a zero
transpose_args.append(ad_util.zeros_like_aval(x.aval))
elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x):
# the loop was 'just setting', grab that cotangent! x is dummy
transpose_args.append(ct)
elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x):
# the loop was 'getting and setting', grab that cotangent!
transpose_args.append(ct)
jaxpr_transpose_, consts = _transpose_jaxpr(jaxpr, which_linear)
jaxpr_transpose = hoist_consts_to_refs(jaxpr_transpose_)
which_linear = (*[False] * len(consts), *which_linear)
const_all_outs = run_state_p.bind(*consts, *transpose_args,
jaxpr=jaxpr_transpose,
which_linear=which_linear)
_, all_outs = split_list(const_all_outs, [len(consts)])
ct_outs = [ct if ad.is_undefined_primal(x) else None
for x, ct in zip(args, all_outs)]
return ct_outs
ad.primitive_transposes[run_state_p] = _run_state_transpose
@register_discharge_rule(run_state_p)
def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
*args: Any, jaxpr: core.Jaxpr,
which_linear: Sequence[bool]):
del out_avals
out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear)
new_invals = []
for aval, out_val in zip(in_avals, out_vals):
new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
return new_invals, out_vals
def initial_style_jaxpr(
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))
@weakref_lru_cache
def _initial_style_jaxpr(fun, in_tree, in_avals):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
tree_util.treedef_tuple((in_tree,)))
debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree_thunk()
def run_state(f: Callable[..., None]):
def wrapped(args):
flat_args, in_tree = tree_util.tree_flatten(args)
avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args]
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals))
jaxpr = hoist_consts_to_refs(jaxpr_)
which_linear = (False,) * (len(consts) + len(flat_args))
out_const_flat = run_state_p.bind(*consts, *flat_args, jaxpr=jaxpr,
which_linear=which_linear)
_, out_flat = split_list(out_const_flat, [len(consts)])
return in_tree.unflatten(out_flat)
return wrapped
def run_state_reference(f: Callable[..., None]):
def wrapped(args):
flat_args, in_tree = tree_util.tree_flatten(args)
avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args]
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals))
jaxpr = hoist_consts_to_refs(jaxpr_)
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
out_const_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
*consts, *args)
_, out_flat = split_list(out_const_flat, [len(consts)])
return in_tree.unflatten(out_flat)
return wrapped

View File

@ -18,8 +18,7 @@ from typing import Any, List, Tuple, Union
import numpy as np
from jax import lax
import jax
from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
@ -421,7 +420,7 @@ def _get_vmap(batched_args, batched_dims, *, indexed_dims):
# `idxs` doesn't include the non indexed dims.
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
else:
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
@ -454,7 +453,7 @@ def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
bdim_out = 0
@ -487,7 +486,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
idxs_shape, = {i.shape for i in idxs} or [()]
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []

View File

@ -20,13 +20,8 @@ from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
from jax._src import core
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import safe_map, safe_zip
xc = xla_client
xb = xla_bridge
## JAX utilities
map, unsafe_map = safe_map, map
@ -107,25 +102,25 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
@core.aval_method
@staticmethod
def get(tracer, idx=()):
from jax._src.state.primitives import ref_get
from jax._src.state.primitives import ref_get # pytype: disable=import-error
return ref_get(tracer, idx)
@core.aval_method
@staticmethod
def set(tracer, value, idx=()):
from jax._src.state.primitives import ref_set
from jax._src.state.primitives import ref_set # pytype: disable=import-error
return ref_set(tracer, idx, value)
def _getitem(self, tracer, idx) -> Array:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_get
from jax._src.state.primitives import ref_get # pytype: disable=import-error
return ref_get(tracer, idx)
def _setitem(self, tracer, idx, value) -> None:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_set
from jax._src.state.primitives import ref_set # pytype: disable=import-error
return ref_set(tracer, idx, value)
def __repr__(self) -> str:

54
jax/_src/state/utils.py Normal file
View File

@ -0,0 +1,54 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for tracing stateful functions."""
from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
from jax._src.state import AbstractRef
from jax._src.util import (partition_list, merge_lists, split_list, safe_map,
safe_zip)
from jax._src.state.primitives import ref_get
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
all_const_avals = [var.aval for var in jaxpr.constvars]
is_const_ref = [isinstance(var.aval, AbstractRef) for var in
jaxpr.constvars]
const_avals_, const_ref_avals = partition_list(is_const_ref, all_const_avals)
const_avals = map(AbstractRef, const_avals_)
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
arg_avals = [var.aval for var in jaxpr.invars]
in_avals = [*merged_const_avals, *arg_avals]
num_consts = len(merged_const_avals)
def _hoist(*consts_args):
all_consts, args = split_list(consts_args, [num_consts])
consts, const_refs = partition_list(is_const_ref, all_consts)
# We immediately read the const values out of the `Ref`s.
consts = map(lambda x: ref_get(x, ()), consts)
all_consts = merge_lists(is_const_ref, consts, const_refs)
return core.eval_jaxpr(jaxpr, all_consts, *args)
hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
def val_to_ref_aval(x) -> AbstractRef:
aval = core.raise_to_shaped(core.get_aval(x))
if type(aval) is not core.ShapedArray:
raise Exception(f"can't make ref from {x}")
return AbstractRef(aval)

View File

@ -998,10 +998,23 @@ jax_test(
jax_test(
name = "state_test",
srcs = ["state_test.py"],
# Use fewer cases to prevent timeouts.
args = [
"--jax_num_generated_cases=5",
],
backend_variant_args = {
"tpu_pjrt_c_api": ["--jax_num_generated_cases=1"],
},
enable_configs = [
"gpu",
"cpu",
],
shard_count = {
"cpu": 2,
"gpu": 2,
"tpu": 2,
},
deps = py_deps("hypothesis"),
)
jax_test(

File diff suppressed because it is too large Load Diff