Merge pull request #15149 from sharadmv:runstate
PiperOrigin-RevId: 521809360
This commit is contained in:
commit 3c1f3abba2

jax/BUILD (15 lines changed)
@@ -364,6 +364,7 @@ pytype_strict_library(
         ":effects",
         ":profiler",
         ":source_info_util",
+        ":state_types",
         ":tree_util",
         ":util",
     ] + py_deps("numpy"),
@@ -411,6 +412,20 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "state_types",
+    srcs = [
+        "_src/state/__init__.py",
+        "_src/state/types.py",
+    ],
+    deps = [
+        ":core",
+        ":effects",
+        ":pretty_printer",
+        ":util",
+    ],
+)
+
 pytype_strict_library(
     name = "tree_util",
     srcs = ["_src/tree_util.py"],
@@ -42,6 +42,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
                            JaxprEqn, Primitive, ShapedArray, DShapedArray,
                            mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
                            InputType, OutputType, get_referent)
+from jax._src.state.types import AbstractRef
 from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
                                 KeyPath, generate_key_paths, keystr)
 from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@@ -1123,9 +1124,36 @@ def partial_eval_jaxpr_custom(
     ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
   if type(ensure_out_inst) is bool:
     ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
-  return _partial_eval_jaxpr_custom_cached(
-      jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
-      tuple(ensure_out_inst), saveable)
+  jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
+      _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
+                                        tuple(in_inst),
+                                        tuple(ensure_out_unknowns),
+                                        tuple(ensure_out_inst), saveable)
+  if num_res_ref > 0:
+    raise ValueError(
+        "Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
+  return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res
+
+def partial_eval_jaxpr_stateful(
+    jaxpr: Jaxpr,
+    in_unknowns: Sequence[bool],
+    in_inst: Union[bool, Sequence[bool]],
+    ensure_out_unknowns: Union[bool, Sequence[bool]],
+    ensure_out_inst: Union[bool, Sequence[bool]],
+    saveable: Callable[..., bool],
+) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
+  if type(in_inst) is bool:
+    in_inst = (in_inst,) * len(jaxpr.invars)
+  if type(ensure_out_unknowns) is bool:
+    ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
+  if type(ensure_out_inst) is bool:
+    ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
+  jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
+      _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
+                                        tuple(in_inst),
+                                        tuple(ensure_out_unknowns),
+                                        tuple(ensure_out_inst), saveable)
+  return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref
 
 @weakref_lru_cache
 def _partial_eval_jaxpr_custom_cached(
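For context, a minimal usage sketch (not part of this diff) of how the new entry point splits a jaxpr into known and staged parts, assuming it lands as `pe.partial_eval_jaxpr_stateful` in `jax._src.interpreters.partial_eval`:

import jax
from jax._src.interpreters import partial_eval as pe

def f(x, y):
  return x * 2.0 + y

jaxpr = jax.make_jaxpr(f)(1.0, 2.0).jaxpr
# Treat `y` as unknown: everything computable from `x` alone goes into the
# known jaxpr; the rest is staged out, with residuals counted separately.
known, staged, out_uk, out_inst, num_res, num_res_ref = \
    pe.partial_eval_jaxpr_stateful(
        jaxpr, in_unknowns=[False, True], in_inst=[False, True],
        ensure_out_unknowns=False, ensure_out_inst=False,
        saveable=lambda *_, **__: True)
assert num_res_ref == 0  # a stateless jaxpr has no `Ref` residuals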
@@ -1135,9 +1163,10 @@ def _partial_eval_jaxpr_custom_cached(
     ensure_out_unknowns: Tuple[bool, ...],
     ensure_out_inst: Tuple[bool, ...],
     saveable: Callable[..., bool],
-) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
+) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
   env: Dict[Var, Tuple[bool, bool]] = {}
   residuals: OrderedSet[Var] = OrderedSet()
+  residual_refs: OrderedSet[Var] = OrderedSet()
 
   def read(x: Atom) -> Tuple[bool, bool]:
     if type(x) is Var:
@@ -1162,7 +1191,11 @@ def _partial_eval_jaxpr_custom_cached(
     if rule:
       eqn1, eqn2, unks_out, inst_out, res = rule(saveable, unks_in, inst_in, eqn)
       eqn1 and known_eqns.append(eqn1); eqn2 and staged_eqns.append(eqn2)  # type: ignore
-      residuals.update(res)
+      for r in res:
+        if isinstance(r.aval, AbstractRef):
+          residual_refs.add(r)
+        else:
+          residuals.add(r)
       map(write, unks_out, inst_out, eqn.outvars)
     elif any(unks_in):
       inputs = map(ensure_instantiated, inst_in, eqn.invars)
@@ -1189,24 +1222,27 @@ def _partial_eval_jaxpr_custom_cached(
 
   ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
   outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
+  ref_res_is_input = [r in ins_known for r in residual_refs]
+  non_input_res_refs, _ = partition_list(ref_res_is_input, list(residual_refs))
+  ins_known_and_ref_res = [*ins_known, *non_input_res_refs]
   known_outvars = [*outs_known, *residuals]
-  known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known, known_outvars,
-                                     known_eqns)
-  jaxpr_known = Jaxpr(jaxpr.constvars, ins_known, known_outvars,
+  known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res,
+                                     known_outvars, known_eqns)
+  jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
                       known_eqns, known_effects)
   config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
 
   _, ins_staged = partition_list(in_inst, jaxpr.invars)
   _, outs_staged = partition_list(out_inst, jaxpr.outvars)
-  staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns))
-  staged_invars = [*residuals, *ins_staged]
+  staged_invars = [*residuals, *non_input_res_refs, *ins_staged]
+  staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars,
+                                      outs_staged, staged_eqns)
   jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
                        outs_staged, staged_eqns, staged_effects)
   config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
 
-  return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
+  return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
+          len(non_input_res_refs))
 
 # A primitive rule for policy-driven partial evaluation returns a 5-tuple
 # with the components representing, respectively:
@@ -1283,9 +1319,10 @@ def closed_call_partial_eval_custom_rule(
 ) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
   # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
   closed_jaxpr = eqn.params[jaxpr_param_name]
-  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
-      partial_eval_jaxpr_custom(closed_jaxpr.jaxpr, unks_in, inst_in,
-                                False, False, saveable)
+  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \
+      partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in,
+                                  False, False, saveable)
+  num_res = num_res_ref + num_res_out
   # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
   jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
   jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
@@ -1299,18 +1336,25 @@ def closed_call_partial_eval_custom_rule(
   params_known, params_staged = params_updater(
       unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
       params_staged)
-  residuals = [newvar(res_aval(params_known, a))
-               for a in jaxpr_staged.in_avals[:num_res]]
-  eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
+  residuals, ref_residuals = split_list(
+      [newvar(res_aval(params_known, v)) for v
+       in jaxpr_staged.in_avals[:num_res]], [num_res_out])
+  eqn_known = new_jaxpr_eqn([*ins_known, *ref_residuals],
+                            [*out_binders_known, *residuals],
                             eqn.primitive, params_known, jaxpr_known.effects,
                             eqn.source_info)
-  eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
+  eqn_staged = new_jaxpr_eqn([*residuals, *ref_residuals, *ins_staged],
+                             out_binders_staged,
                              eqn.primitive, params_staged, jaxpr_staged.effects,
                              eqn.source_info)
   assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
+  assert len(ins_known) + len(ref_residuals) == len(jaxpr_known.jaxpr.invars)
+  assert len(ins_staged) + len(ref_residuals) + len(residuals) == len(jaxpr_staged.jaxpr.invars)
+  assert len(out_binders_known) + len(residuals) == len(jaxpr_known.jaxpr.outvars)
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is Var and not inst]
-  return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
+  new_vars = [*new_inst, *residuals, *ref_residuals]
+  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
 
 partial_eval_jaxpr_custom_rules[core.call_p] = \
     partial(call_partial_eval_custom_rule, 'call_jaxpr',
@@ -1536,16 +1580,25 @@ class DynamicJaxprTracer(core.Tracer):
 api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
 
 def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
-  del outvars
   jaxpr_effects = set()
   all_vars = [*constvars, *invars]
   for eqn in eqns:
     for eff in eqn.effects:
       if isinstance(eff, effects.JaxprInputEffect):
+        if eff.input_index >= len(eqn.invars):
+          raise ValueError(
+              f"`JaxprInputEffect` {eff} is invalid."
+              f"\n Equation: {eqn}\n"
+              "\n Jaxpr: "
+              f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
         invar = eqn.invars[eff.input_index]
         if invar not in all_vars:
           raise ValueError(
-              "`JaxprInputEffect` does not have corresponding input.")
+              f"`JaxprInputEffect` {eff} does not have "
+              f"corresponding input: {invar}."
+              f"\n Equation: {eqn}\n"
+              "\n Jaxpr: "
+              f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
         eff = eff.replace(input_index=all_vars.index(invar))
       jaxpr_effects.add(eff)
   return jaxpr_effects
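The re-indexing step above matters because a `JaxprInputEffect` identifies its operand by position; when equations are regrouped into a new jaxpr, the index must be rebased against that jaxpr's constvars and invars. A sketch of the idea (a hypothetical effect value, assuming the `ReadEffect` class from `jax._src.state.types` and the `replace` method used in this diff):

from jax._src.state.types import ReadEffect

eff = ReadEffect(2)  # effect on an equation's third operand
# If that operand turns out to be input 0 of the enclosing jaxpr, the effect
# is re-pointed before being attached to the jaxpr:
rebased = eff.replace(input_index=0)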
@@ -31,7 +31,8 @@ from jax._src import effects
 from jax._src import linear_util as lu
 from jax._src import source_info_util
 from jax._src import util
-from jax._src import state
+from jax._src.state.discharge import register_discharge_rule, discharge_state
+from jax._src.state.types import AbstractRef, RefEffect
 from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -237,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
 
   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       (true_fun, false_fun), ops_tree, ops_avals, 'cond')
-  if any(isinstance(op_aval, state.AbstractRef) for op_aval in ops_avals):
+  if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals):
     raise ValueError("Cannot pass `Ref`s into `cond`.")
   true_jaxpr, false_jaxpr = jaxprs
   out_tree, false_out_tree = out_trees
@@ -338,7 +339,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
   index, *ops = args
   index_dim, *op_dims = dims
   # TODO(sharadmv): clean this up by adding a specific blocklist
-  if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
+  if any(isinstance(eff, RefEffect) for branch in branches for eff in
          branch.jaxpr.effects):
     raise NotImplementedError(
         "State effect not supported in vmap-of-cond.")
@@ -426,7 +427,7 @@ def _cond_jvp(primals, tangents, branches, linear):
 def _cond_partial_eval(trace, *tracers, branches, linear):
   in_unknowns = [t.pval[0] is not None for t in tracers]
   index_uk, *ops_uk = in_unknowns
-  if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
+  if any(isinstance(eff, RefEffect) for branch in branches for eff in
          branch.jaxpr.effects):
     raise NotImplementedError(
         "State effect not supported in cond partial-eval.")
@@ -703,7 +704,7 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear):
   linear = [type(x) is ad.UndefinedPrimal for x in ops]
   in_avals = map(raise_to_shaped, branches[0].in_avals)
   num_res = len(ops) - sum(linear)
-  if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
+  if any(isinstance(eff, RefEffect) for branch in branches for eff in
          branch.jaxpr.effects):
     raise NotImplementedError("State effect not supported in cond transpose.")
 
@@ -856,10 +857,10 @@ def _cond_lowering(ctx, index, *args, branches, linear):
 
 mlir.register_lowering(cond_p, _cond_lowering)
 
-@state.register_discharge_rule(cond_p)
+@register_discharge_rule(cond_p)
 def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
   discharged_branches = tuple(
-      core.ClosedJaxpr(state.discharge_state(branch.jaxpr, ())[0], ())
+      core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ())
       for branch in branches)
   out_vals = cond_p.bind(*args, branches=discharged_branches, linear=linear)
   out_ref_vals, out_vals = util.split_list(
@@ -868,5 +869,5 @@ def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
   new_invals = []
   for aval in in_avals:
     new_invals.append(
-        next(ref_val_iter) if isinstance(aval, state.AbstractRef) else None)
+        next(ref_val_iter) if isinstance(aval, AbstractRef) else None)
   return new_invals, out_vals
@@ -33,7 +33,11 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import source_info_util
-from jax._src import state
+from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect)
+from jax._src.state import discharge as state_discharge
+from jax._src.state import primitives as state_primitives
+from jax._src.state import utils as state_utils
+from jax._src.state import types as state_types
 from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                            split_list, split_dict)
 from jax._src.lax.control_flow import loops
@@ -50,15 +54,10 @@ T = TypeVar('T')
 class Ref(Generic[T]): pass
 Array = Any
 
-ReadEffect = state.ReadEffect
-WriteEffect = state.WriteEffect
-AccumEffect = state.AccumEffect
-StateEffect = state.StateEffect
-AbstractRef = state.AbstractRef
-ref_set = state.ref_set
-ref_get = state.ref_get
-ref_addupdate = state.ref_addupdate
-discharge_state = state.discharge_state
+ref_set = state_primitives.ref_set
+ref_get = state_primitives.ref_get
+ref_addupdate = state_primitives.ref_addupdate
+discharge_state = state_discharge.discharge_state
 
 
 ## `for_loop` implementation
@@ -100,12 +99,6 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
       f, state_avals)
   return jaxpr, consts, out_tree_thunk()
 
-def val_to_ref_aval(x) -> AbstractRef:
-  aval = core.raise_to_shaped(core.get_aval(x))
-  if type(aval) is not core.ShapedArray:
-    raise Exception(f"can't make ref from {x}")
-  return AbstractRef(aval)
-
 def for_loop(nsteps: Union[int, Sequence[int]],
              body: Callable[[Array, Ref[S]], None], init_state: S,
              *, reverse: bool = False, unroll: int = 1) -> S:
@@ -151,14 +144,14 @@ def for_loop(nsteps: Union[int, Sequence[int]],
   if len(nsteps) > 1:
     outer_step, *rest_steps = nsteps
     def wrapped_body(i, refs):
-      vals = tree_map(lambda ref: state.ref_get(ref, ()), refs)
+      vals = tree_map(lambda ref: ref_get(ref, ()), refs)
       vals = for_loop(
           rest_steps, functools.partial(body, i), vals, unroll=unroll)
-      tree_map(lambda ref, val: state.ref_set(ref, (), val), refs, vals)
+      tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals)
     return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
   nsteps, = nsteps
   flat_state, state_tree = tree_flatten(init_state)
-  state_avals = map(val_to_ref_aval, flat_state)
+  state_avals = map(state_utils.val_to_ref_aval, flat_state)
   idx_aval = core.ShapedArray((), jnp.dtype("int32"))
   jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
       body, state_tree, [idx_aval, *state_avals])
@@ -247,7 +240,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
 @for_p.def_effectful_abstract_eval
 def _for_abstract_eval(*avals, jaxpr, **__):
   # Find out for each of the `Ref`s in our jaxpr what effects they have.
-  jaxpr_aval_effects = state.get_ref_state_effects(
+  jaxpr_aval_effects = state_types.get_ref_state_effects(
       [v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
   aval_effects = [set(eff.replace(input_index=eff.input_index - 1)
                       for eff in effs) for aval, effs
@@ -256,7 +249,7 @@ def _for_abstract_eval(*avals, jaxpr, **__):
   nonlocal_state_effects = core.join_effects(*aval_effects)
   return list(avals), nonlocal_state_effects
 
-@state.register_discharge_rule(for_p)
+@state_discharge.register_discharge_rule(for_p)
 def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
                         reverse: bool, which_linear: Sequence[bool],
                         nsteps: int, unroll: int
@@ -388,7 +381,7 @@ def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
 
 def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
   # Get effects for each of the jaxpr inputs and remove the loop index.
-  ref_effects = state.get_ref_state_effects(
+  ref_effects = state_types.get_ref_state_effects(
       [v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
   # We first assume that *read-only `Ref`s* are loop-invariant. We can safely do
   # this because the only way something can be loop-varying is if we write to it
@@ -771,7 +764,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
   Potentially useful for testing and benchmarking.
   """
   flat_state, state_tree = tree_flatten(init_state)
-  state_avals = map(val_to_ref_aval, flat_state)
+  state_avals = map(state_utils.val_to_ref_aval, flat_state)
   idx_aval = core.ShapedArray((), jnp.dtype("int32"))
   jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
       body, state_tree, [idx_aval, *state_avals])
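A usage sketch for `for_loop` (not part of this diff; behavior as described by the internal API above, assuming the module path `jax._src.lax.control_flow.for_loop`):

import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop

def body(i, ref):
  # `ref` holds the loop-carried state; reads and writes go through the Ref.
  ref[...] = ref[...] + i

total = for_loop(5, body, jnp.float32(0.0))
# total == 10.0, i.e. 0 + 1 + 2 + 3 + 4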
@@ -15,7 +15,3 @@
 from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
                                   AccumEffect, StateEffect, RefEffect,
                                   get_ref_state_effects, shaped_array_ref)
-from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
-                                       ref_addupdate, get_p, swap_p,
-                                       addupdate_p)
-from jax._src.state.discharge import discharge_state, register_discharge_rule
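The effect of this hunk is that `jax._src.state` keeps only the type re-exports; primitives and discharge utilities now come from their own modules. A sketch of the resulting import layout (all paths as they appear in this diff):

from jax._src.state import AbstractRef, shaped_array_ref  # still re-exported
from jax._src.state.primitives import ref_get, ref_set    # moved
from jax._src.state.discharge import discharge_state      # moved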
@@ -15,24 +15,34 @@
 from __future__ import annotations
 import dataclasses
 from functools import partial
 import operator
 
-from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union
 
 import numpy as np
 
+from jax import lax
+
+from jax._src import api_util
+from jax._src import ad_util
 from jax._src import core
 from jax._src import linear_util as lu
 from jax._src.interpreters import partial_eval as pe
-from jax._src.state.types import AbstractRef
+from jax._src import source_info_util
+from jax._src import tree_util
+from jax._src.config import config
+from jax._src.interpreters import ad
+from jax._src.state.types import AbstractRef, RefEffect
 from jax._src.state.primitives import get_p, swap_p, addupdate_p
-from jax._src.util import safe_map, safe_zip, split_list
+from jax._src.state.utils import hoist_consts_to_refs
+from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
+                           partition_list, merge_lists, split_dict)
 
 ## JAX utilities
 
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
+PyTreeDef = tree_util.PyTreeDef
 
 ## Discharging state
@@ -241,3 +251,462 @@ def _closed_call_discharge_rule(
       else None for aval in in_avals)
   assert next(ref_vals_iter, None) is None
   return new_invals, out_vals
+
+## `run_state`
+
+run_state_p = core.Primitive("run_state")
+run_state_p.multiple_results = True
+
+def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
+                    which_linear: tuple[bool, ...]):
+  if config.jax_enable_checks:
+    core.check_jaxpr(jaxpr)
+    assert len(jaxpr.invars) == len(args)
+    assert len(which_linear) == len(args)
+  return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr,
+                             which_linear=which_linear)
+run_state_p.def_custom_bind(_run_state_bind)
+
+def _run_state_impl(*args: Any, jaxpr: core.Jaxpr,
+                    which_linear: tuple[bool, ...]):
+  del which_linear
+  discharged_jaxpr, consts = discharge_state(jaxpr, ())
+  return core.eval_jaxpr(discharged_jaxpr, consts, *args)
+run_state_p.def_impl(_run_state_impl)
+
+def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
+                             which_linear: tuple[bool, ...]):
+  del which_linear
+  # When we abstractly evaluate `run_state`, we want to keep track of which
+  # input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to
+  # "propagate" out its inner effects. Otherwise, the effects are local to this
+  # `run_state`.
+  is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)}
+  nonlocal_effects = {e for e in jaxpr.effects
+                      if (isinstance(e, RefEffect) and e.input_index in is_ref)
+                      or not isinstance(e, RefEffect)}
+  return avals, nonlocal_effects
+run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)
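`_run_state_impl` above evaluates the stateful jaxpr by discharging it first. A small sketch of that discharge step on its own (not part of the diff; it reuses `shaped_array_ref` and the tracing entry points shown elsewhere in this commit):

import jax.numpy as jnp
from jax._src import core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.state import shaped_array_ref
from jax._src.state.discharge import discharge_state

def body(x_ref):
  x_ref[...] = x_ref[...] + 1.0
  return []

stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
    lu.wrap_init(body), [shaped_array_ref((), jnp.dtype('float32'))])
# Discharging turns the Ref input into a value input and returns the final
# Ref contents as an explicit output.
pure_jaxpr, new_consts = discharge_state(stateful_jaxpr, consts)
print(core.eval_jaxpr(pure_jaxpr, new_consts, jnp.float32(41.0)))  # [42.0]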
+
+def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *,
+                   jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
+  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
+  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
+  for _ in range(len(nonzero_tangents)):
+    _, out_nonzero_tangents = ad.jvp_jaxpr(
+        core.ClosedJaxpr(discharged_jaxpr, body_consts),
+        nonzero_tangents, instantiate=nonzero_tangents)
+    if out_nonzero_tangents == nonzero_tangents:
+      break
+    nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents)
+  else:
+    raise Exception("Invalid fixpoint")
+  del discharged_jaxpr, body_consts, out_nonzero_tangents
+  tangents = [ad.instantiate_zeros(t) if inst else t
+              for t, inst in zip(tangents, nonzero_tangents)]
+  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
+  closed_jvp_jaxpr, _ = ad.jvp_jaxpr(core.ClosedJaxpr(jaxpr, ()),
+                                     nonzero_tangents, [])
+  jvp_jaxpr_, jvp_consts = closed_jvp_jaxpr.jaxpr, closed_jvp_jaxpr.consts
+  jvp_jaxpr = hoist_consts_to_refs(jvp_jaxpr_)
+  jvp_which_linear = (*(False,) * len(jvp_consts), *which_linear, *(True,) * len(tangents))
+  out = run_state_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr,
+                         which_linear=jvp_which_linear)
+  out_consts, out_primals, out_tangents = split_list(out, [len(jvp_consts),
+                                                           len(primals)])
+  del out_consts
+  out_tangents_iter = iter(out_tangents)
+  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
+                  for p, nz in zip(out_primals, nonzero_tangents)]
+  return out_primals, out_tangents
+ad.primitive_jvps[run_state_p] = _run_state_jvp
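With this JVP rule in place, `jax.jvp` can flow through a `run_state` region. A sketch (not part of the diff, assuming the `run_state` wrapper defined later in this file):

import jax
import jax.numpy as jnp
from jax._src.state.discharge import run_state

def square(refs):
  x_ref, o_ref = refs
  o_ref[...] = x_ref[...] * x_ref[...]

f = lambda x: run_state(square)((x, jnp.zeros_like(x)))[1]
y, y_dot = jax.jvp(f, (jnp.float32(3.0),), (jnp.float32(1.0),))
# y == 9.0, y_dot == 6.0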
+
+_save_everything = lambda *_, **__: True
+
+def _convert_outputs_to_writes(
+    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
+  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
+
+  in_avals = [v.aval for v in jaxpr.invars]
+  @lu.wrap_init
+  def eval_jaxpr(*refs):
+    # We split the refs into the original input refs and the dummy residual
+    # refs.
+    orig_refs, residual_refs = split_list(refs, [len(in_avals)])
+    residual_vals = core.eval_jaxpr(jaxpr, (), *orig_refs)
+    for res_ref, res_val in zip(residual_refs, residual_vals):
+      res_ref[...] = res_val
+    return []
+  res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
+                   else v.aval for v in jaxpr.outvars]
+  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+      eval_jaxpr, [*in_avals, *res_ref_avals])
+  assert not consts
+  return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype)  # pytype: disable=attribute-error
+                 for a in res_ref_avals]
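To make the shape of this transformation concrete: a jaxpr from `Ref` inputs to value outputs becomes one that takes extra residual `Ref`s and writes into them instead of returning. A sketch (reusing the tracing pattern above; note `_convert_outputs_to_writes` is private to this module):

import jax.numpy as jnp
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.state import shaped_array_ref
from jax._src.state.discharge import _convert_outputs_to_writes

def read_plus_one(x_ref):
  return [x_ref[...] + 1.0]

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
    lu.wrap_init(read_plus_one), [shaped_array_ref((), jnp.dtype('float32'))])
out_jaxpr, res_avals = _convert_outputs_to_writes(jaxpr)
# out_jaxpr now has two Ref invars (original + one residual Ref) and no outvars.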
+
+def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
+  assert not jaxpr.constvars, "Jaxpr should not have constvars"
+
+  @lu.wrap_init
+  def eval_jaxpr(*refs):
+    residual_refs, orig_refs = split_list(refs, [num_res])
+    residual_vals = [r[...] for r in residual_refs]
+    () = core.eval_jaxpr(jaxpr, (), *residual_vals, *orig_refs)
+    return []
+
+  res_val_avals, orig_ref_avals = \
+      split_list([v.aval for v in jaxpr.invars], [num_res])
+  res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
+                   aval for aval in res_val_avals]
+  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
+      eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
+  return jaxpr
+
+def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
+                            jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
+  num_inputs = len(tracers)
+  assert num_inputs == len(jaxpr.invars)
+  in_unknowns = [not t.pval.is_known() for t in tracers]
+  # We first need to run a fixpoint to determine which of the `Ref`s are unknown
+  # after running the for loop. We want to use the jaxpr to determine which
+  # `Ref`s are unknown after executing the for loop body given which `Ref`s are
+  # unknown before. However, the jaxpr has no outputs. Instead, we discharge
+  # the body and run the fixpoint with the discharged jaxpr. We can do this
+  # because the outputs of the jaxpr are one-to-one with the inputs.
+  discharged_jaxpr_, discharged_consts = discharge_state(jaxpr, ())
+  discharged_jaxpr = pe.convert_constvars_jaxpr(discharged_jaxpr_)
+  for _ in range(num_inputs):
+    jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
+    _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
+        discharged_jaxpr, jaxpr_in_unknowns, jaxpr_in_unknowns,
+        in_unknowns, False, _save_everything)
+    # assert out_inst == out_unknowns
+    out_unknowns = list(out_unknowns)
+    if out_unknowns == in_unknowns:
+      break
+    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
+  else:
+    raise Exception("Invalid fixpoint")
+  del out_unknowns  # redundant since it's the same as `in_unknowns`
+  tracers = tuple(trace.instantiate_const(t) if uk else t  # type: ignore
+                  for t, uk in zip(tracers, in_unknowns))
+
+  # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
+  # primitives like `get`/`set`.
+  jaxpr_known_resout, jaxpr_unknown_resin_, _, _, num_res_out, num_res_ref = \
+      pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns, in_inst=in_unknowns,
+                                     ensure_out_unknowns=[], ensure_out_inst=[],
+                                     saveable=_save_everything)
+  # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref`
+  # and regular valued input/outputs. However, we'd like to bind these jaxprs to
+  # a `for`, which expects only `Ref` inputs and no output. We need to convert
+  # both of these jaxprs into ones that are compatible with `for`.
+
+  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
+  # to output residual values (none of them should be `Ref`s). We'll need to
+  # convert the output residual values into `Ref`s that are initially empty
+  # `Ref`s that are written to at the end of the jaxpr.
+  num_res = num_res_out + num_res_ref
+
+  num_invars = len(jaxpr_known_resout.invars) - num_res_ref
+  _, res_ref_avals = split_list(
+      [v.aval for v in jaxpr_known_resout.invars], [num_invars])
+  res_avals = [a.inner_aval for a in res_ref_avals]  # pytype: disable=attribute-error
+  jaxpr_known, new_res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
+  # We now run the known jaxpr to obtain our residual values.
+  known_tracers, _ = partition_list(in_unknowns, tracers)
+  known_which_linear, _ = partition_list(in_unknowns, which_linear)
+  known_vals = [t.pval.get_known() for t in known_tracers]
+  all_res_avals = [*res_avals, *new_res_avals]
+  empty_res = map(ad_util.zeros_like_aval, all_res_avals)
+  jaxpr_known_args = [*known_vals, *empty_res]
+
+  jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
+  out_flat = run_state_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known,
+                              which_linear=jaxpr_known_which_linear)
+  known_outputs, residuals = split_list(out_flat, [len(known_tracers)])
+  residuals = map(trace.new_instantiated_const, residuals)
+  ref_res, nonref_res = split_list(residuals, [num_res_ref])
+
+  # Now we handle the `jaxpr_unknown` that expects residual values as inputs.
+  # This jaxpr is the output of `partial_eval_jaxpr_stateful` that marks which
+  # inputs are actually used.
+  # `partial_eval_jaxpr_stateful` doesn't remove extra inputs/outputs for you
+  # so we use `dce_jaxpr` here to do that.
+  # To make it compatible with `for`, we need to convert those residual values
+  # into `Ref`s.
+  jaxpr_unknown = _convert_inputs_to_reads(len(new_res_avals),
+                                           jaxpr_unknown_resin_)
+  _, unknown_tracers = partition_list(in_unknowns, tracers)
+  _, uk_which_linear = partition_list(in_unknowns, which_linear)
+  unknown_which_linear = (False,) * num_res + tuple(uk_which_linear)
+  unknown_inputs = [*nonref_res, *ref_res, *unknown_tracers]
+  # Outputs match inputs so we construct output tracers that look like the input
+  # tracers.
+  res_ref_unknown_outputs = [
+      pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None)
+      for t in unknown_inputs]
+  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
+  source = source_info_util.current().replace(name_stack=name_stack)
+
+  assert len(unknown_inputs) == len(res_ref_unknown_outputs)
+  assert len(unknown_inputs) == len(jaxpr_unknown.invars)
+  uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear)
+  _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
+                                             **uk_params)
+  eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
+                          run_state_p, uk_params,
+                          eqn_effects, source)
+  for t in res_ref_unknown_outputs: t.recipe = eqn
+  _, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res])
+  return merge_lists(in_unknowns, known_outputs, unknown_outputs)
+pe.custom_partial_eval_rules[run_state_p] = _run_state_partial_eval
+
+def _run_state_partial_eval_custom(
+    saveable: Callable[..., bool],
+    in_unknowns: Sequence[bool],
+    in_inst: Sequence[bool],
+    eqn: core.JaxprEqn):
+  if not any(in_unknowns):
+    return eqn, None, in_unknowns, [False] * len(in_unknowns), []
+  jaxpr, which_linear = split_dict(eqn.params, ["jaxpr", "which_linear"])
+  num_inputs = len(eqn.invars)
+  # We first need to run a fixpoint to determine which of the `Ref`s are unknown
+  # after running the for loop. However, the jaxpr has no outputs. Instead, we
+  # discharge the body and run the fixpoint with the discharged jaxpr. We can do
+  # this because the outputs of the discharged jaxpr are one-to-one with the
+  # inputs.
+  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
+  discharged_jaxpr = discharged_jaxpr.replace(
+      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
+      constvars=[])
+  in_unknowns, in_inst = list(in_unknowns), list(in_inst)
+  out_unknowns, out_inst = in_unknowns, in_unknowns
+  for _ in range(num_inputs):
+    jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
+    _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
+        discharged_jaxpr,
+        in_unknowns=jaxpr_in_unknowns,
+        in_inst=jaxpr_in_unknowns,
+        ensure_out_unknowns=in_unknowns,
+        ensure_out_inst=in_unknowns,
+        saveable=saveable)
+    out_unknowns = list(out_unknowns)
+    if out_unknowns == in_unknowns:
+      break
+    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
+  else:
+    if num_inputs > 0: raise Exception("Invalid fixpoint")
+  del out_unknowns  # Redundant since it's the same as `in_unknowns`
+  new_inst = [x for x, already, inst in zip(eqn.invars, in_inst, out_inst)
+              if type(x) is core.Var and inst and not already]
+
+  # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
+  # primitives like `get`/`set`.
+  jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res_out, num_res_ref = \
+      pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns,
+                                     in_unknowns, [], [], saveable)
+  num_res = num_res_ref + num_res_out
+  # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref` and
+  # non-Ref input/outputs. However, we'd like to bind these jaxprs to a
+  # `for`, which expects only `Ref` inputs and no output. We need to convert
+  # both of these jaxprs into ones that are compatible with `for`.
+  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
+
+  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
+  # to output residual values (none of them should be `Ref`s). We'll need to
+  # convert the output residual values into `Ref`s that are initially empty
+  # `Ref`s that are written to at the end of the jaxpr.
+  jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
+
+  # In a stateful partial_eval, the residuals should be `Ref`s.
+  res_avals = map(AbstractRef, res_avals)  # type: ignore
+
+  known_invars, staged_invars = partition_list(in_unknowns, eqn.invars)
+  known_outvars, staged_outvars = partition_list(in_unknowns, eqn.outvars)
+  newvar = core.gensym()
+  _, res_ref_avals = split_list([v.aval for v in jaxpr_known_resout.invars],
+                                [len(known_invars)])
+  nonref_resvars = map(newvar, res_avals)
+  ref_resvars = map(newvar, res_ref_avals)
+  known_out_resvars = map(newvar, [*res_ref_avals, *res_avals])
+
+  known_which_linear, _ = partition_list(in_unknowns, which_linear)
+  jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
+  known_and_res_invars = [*known_invars, *ref_resvars, *nonref_resvars]
+
+  known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear)
+  _, known_effects = run_state_p.abstract_eval(
+      *[v.aval for v in known_and_res_invars], **known_params)
+  eqn_known = pe.new_jaxpr_eqn(known_and_res_invars,
+                               [*known_outvars, *known_out_resvars],
+                               run_state_p, known_params,
+                               known_effects, eqn.source_info)
+
+  jaxpr_staged = _convert_inputs_to_reads(len(res_avals), jaxpr_staged_resin_)
+
+  _, staged_which_linear = partition_list(in_unknowns, which_linear)
+  which_linear_unknown = (*[False] * num_res, *staged_which_linear)
+  staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown)
+  rejiggered_resvars = [*nonref_resvars, *ref_resvars]
+  _, staged_invars = partition_list(in_unknowns, eqn.invars)
+  res_staged_invars = [*rejiggered_resvars, *staged_invars]
+  _, staged_effects = run_state_p.abstract_eval(
+      *[v.aval for v in res_staged_invars], **staged_params)
+  _, staged_outvars = partition_list(in_unknowns, eqn.outvars)
+  if num_res:
+    @lu.wrap_init
+    def staged(*args):
+      out = run_state_p.bind(*args, **staged_params)
+      return out[num_res:]
+    staged_call_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(staged,
+        [v.aval for v in res_staged_invars])
+    eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
+                                  staged_outvars,
+                                  core.closed_call_p,
+                                  dict(call_jaxpr=core.ClosedJaxpr(staged_call_jaxpr, ())),
+                                  staged_effects, eqn.source_info)
+    assert len(res_staged_invars) == len(staged_call_jaxpr.invars)
+    assert len(staged_outvars) == len(staged_call_jaxpr.outvars)
+  else:
+    eqn_staged = pe.new_jaxpr_eqn(staged_invars,
+                                  staged_outvars,
+                                  run_state_p,
+                                  staged_params,
+                                  staged_effects, eqn.source_info)
+  new_vars = [*new_inst, *nonref_resvars, *ref_resvars]
+  return eqn_known, eqn_staged, in_unknowns, in_unknowns, new_vars
+pe.partial_eval_jaxpr_custom_rules[run_state_p] = _run_state_partial_eval_custom
+
+def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
+                     ) -> tuple[core.Jaxpr, Any]:
+  def trans(*args):
+    # First we want to run the computation to read all the residual refs. We can
+    # do that by using partial evaluation with all linear inputs unknown.
+    res_jaxpr_, tangent_jaxpr_, *_, num_res_out, num_res_ref = \
+        pe.partial_eval_jaxpr_stateful(jaxpr, which_linear, in_inst=which_linear,
+                                       ensure_out_inst=[],
+                                       ensure_out_unknowns=[],
+                                       saveable=_save_everything)
+
+    num_unknown = sum(which_linear)
+    num_known = len(jaxpr.invars) - num_unknown
+    res_args, _ = partition_list(which_linear, args)
+    res_jaxpr_avals = [v.aval for v in res_jaxpr_.invars]
+    _, res_avals = split_list(res_jaxpr_avals, [num_known])
+    res_avals = [a.inner_aval for a in res_avals]  # pytype: disable=attribute-error
+    all_avals = [*res_avals, *[v.aval for v in res_jaxpr_.outvars]]
+    empty_res = map(ad.zeros_like_aval, all_avals)
+    res_jaxpr, _ = _convert_outputs_to_writes(res_jaxpr_)
+    res = run_state_p.bind(*res_args, *empty_res, jaxpr=res_jaxpr,
+                           which_linear=(False,) * (len(res_args) + len(empty_res)))
+    res = res[len(res_args):]
+    ref_res_, nonref_res_ = split_list(res, [num_res_ref])
+
+    # Now that we have residual values, we run the tangent jaxpr. It takes as
+    # input the residuals, the loop index, and all the refs (at least, the ones
+    # that are used in the body). Luckily, `tangent_jaxpr_` has all known and
+    # unknown inputs!
+    tangent_jaxpr, used_inputs = pe.dce_jaxpr(tangent_jaxpr_, [])
+    used_res, used_cts = split_list(used_inputs, [len(res)])
+    used_nonref_res, used_ref_res = split_list(used_res, [num_res_out])
+    _, nonref_res = partition_list(used_nonref_res, nonref_res_)
+    _, ref_res = partition_list(used_ref_res, ref_res_)
+    primals_args = [*nonref_res, *ref_res]
+    _, tangent_args = partition_list(which_linear, args)
+    _, ct_args = partition_list(used_cts, tangent_args)
+    ad.backward_pass(
+        tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
+    return []
+  jaxpr_trans, _, consts = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
+  return jaxpr_trans, consts
+
+def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,
+                         which_linear: tuple[bool, ...]):
+  # if any in_ct is nonzero, we definitely want it in args_ (and the
+  # corresponding x in args could be an undefined primal, but doesn't have to be)
+  # for non-res stuff:
+  #   getting and setting => (nonzero ct, UndefinedPrimal arg)
+  #   just setting        => (nonzero ct, not UndefinedPrimal, dummy value)
+  #   just getting        => (zero ct   , UndefinedPrimal arg)
+  # for res stuff:
+  #                          (zero ct   , not UndefinedPrimal)
+  assert any(which_linear)
+  transpose_args = []
+  for x, ct in zip(args, in_cts):
+    if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x):
+      # this is a residual, take x!
+      transpose_args.append(x)
+    elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x):
+      # the loop was 'just getting', plug in a zero
+      transpose_args.append(ad_util.zeros_like_aval(x.aval))
+    elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x):
+      # the loop was 'just setting', grab that cotangent! x is dummy
+      transpose_args.append(ct)
+    elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x):
+      # the loop was 'getting and setting', grab that cotangent!
+      transpose_args.append(ct)
+  jaxpr_transpose_, consts = _transpose_jaxpr(jaxpr, which_linear)
+  jaxpr_transpose = hoist_consts_to_refs(jaxpr_transpose_)
+  which_linear = (*[False] * len(consts), *which_linear)
+  const_all_outs = run_state_p.bind(*consts, *transpose_args,
+                                    jaxpr=jaxpr_transpose,
+                                    which_linear=which_linear)
+  _, all_outs = split_list(const_all_outs, [len(consts)])
+  ct_outs = [ct if ad.is_undefined_primal(x) else None
+             for x, ct in zip(args, all_outs)]
+  return ct_outs
+ad.primitive_transposes[run_state_p] = _run_state_transpose
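Together, the JVP, custom partial-eval, and transpose rules above make `run_state` reverse-mode differentiable. A sketch (not part of the diff):

import jax
import jax.numpy as jnp
from jax._src.state.discharge import run_state

def sin_body(refs):
  x_ref, o_ref = refs
  o_ref[...] = jnp.sin(x_ref[...])

def g(x):
  _, out = run_state(sin_body)((x, jnp.zeros_like(x)))
  return out

print(jax.grad(g)(jnp.float32(1.0)))  # approximately cos(1.0) ~ 0.5403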
+
+@register_discharge_rule(run_state_p)
+def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
+                              out_avals: Sequence[core.AbstractValue],
+                              *args: Any, jaxpr: core.Jaxpr,
+                              which_linear: Sequence[bool]):
+  del out_avals
+  out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear)
+  new_invals = []
+  for aval, out_val in zip(in_avals, out_vals):
+    new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
+  return new_invals, out_vals
+
+def initial_style_jaxpr(
+    fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
+) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
+  return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))
+
+@weakref_lru_cache
+def _initial_style_jaxpr(fun, in_tree, in_avals):
+  fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
+      tree_util.treedef_tuple((in_tree,)))
+  debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
+  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
+  return jaxpr, consts, out_tree_thunk()
+
+def run_state(f: Callable[..., None]):
+  def wrapped(args):
+    flat_args, in_tree = tree_util.tree_flatten(args)
+    avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args]
+    jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals))
+    jaxpr = hoist_consts_to_refs(jaxpr_)
+    which_linear = (False,) * (len(consts) + len(flat_args))
+    out_const_flat = run_state_p.bind(*consts, *flat_args, jaxpr=jaxpr,
+                                      which_linear=which_linear)
+    _, out_flat = split_list(out_const_flat, [len(consts)])
+    return in_tree.unflatten(out_flat)
+  return wrapped
+
+def run_state_reference(f: Callable[..., None]):
+  def wrapped(args):
+    flat_args, in_tree = tree_util.tree_flatten(args)
+    avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args]
+    jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals))
+    jaxpr = hoist_consts_to_refs(jaxpr_)
+    discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
+    out_const_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
+                                     *consts, *args)
+    _, out_flat = split_list(out_const_flat, [len(consts)])
+    return in_tree.unflatten(out_flat)
+  return wrapped
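A usage sketch for the two entry points above: `run_state` binds the primitive (so all the rules in this file apply), while `run_state_reference` eagerly discharges and should agree with it numerically:

import jax.numpy as jnp
from jax._src.state.discharge import run_state, run_state_reference

def swap(refs):
  a_ref, b_ref = refs
  a, b = a_ref[...], b_ref[...]
  a_ref[...], b_ref[...] = b, a

args = (jnp.float32(1.0), jnp.float32(2.0))
print(run_state(swap)(args))            # (2.0, 1.0)
print(run_state_reference(swap)(args))  # same result, reference semantics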
@@ -18,8 +18,7 @@ from typing import Any, List, Tuple, Union
 
 import numpy as np
 
-from jax import lax
-
+import jax
 from jax._src import ad_util
 from jax._src import core
 from jax._src import pretty_printer as pp
@@ -421,7 +420,7 @@ def _get_vmap(batched_args, batched_dims, *, indexed_dims):
     # `idxs` doesn't include the non indexed dims.
     idx_place = [i for i, i_dim in enumerate(indexed_dims)
                  if i_dim].index(ref_dim)
-    iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
+    iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
     idxs = tuple_insert(idxs, idx_place, iota)
   else:
     bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
@@ -454,7 +453,7 @@ def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
     indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
     idx_place = [i for i, i_dim in enumerate(indexed_dims)
                  if i_dim].index(ref_dim)
-    iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
+    iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
     idxs = tuple_insert(idxs, idx_place, iota)
     val = batching.moveaxis(val, val_dim, 0)
     bdim_out = 0
@@ -487,7 +486,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
     idx_place = [i for i, i_dim in enumerate(indexed_dims)
                  if i_dim].index(ref_dim)
     idxs_shape, = {i.shape for i in idxs} or [()]
-    iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
+    iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
     idxs = tuple_insert(idxs, idx_place, iota)
     val = batching.moveaxis(val, val_dim, 0)
   return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
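For reference, `jax.lax.broadcasted_iota`, which these vmap rules use to build the batch index, fills one chosen dimension with its coordinate values:

import numpy as np
import jax

print(jax.lax.broadcasted_iota(np.dtype('int32'), (2, 3), 0))
# [[0 0 0]
#  [1 1 1]]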
@@ -20,13 +20,8 @@ from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
 from jax._src import core
 from jax._src import effects
 from jax._src import pretty_printer as pp
-from jax._src import xla_bridge
-from jax._src.lib import xla_client
 from jax._src.util import safe_map, safe_zip
 
-xc = xla_client
-xb = xla_bridge
-
 ## JAX utilities
 
 map, unsafe_map = safe_map, map
@@ -107,25 +102,25 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
   @core.aval_method
   @staticmethod
   def get(tracer, idx=()):
-    from jax._src.state.primitives import ref_get
+    from jax._src.state.primitives import ref_get  # pytype: disable=import-error
     return ref_get(tracer, idx)
 
   @core.aval_method
   @staticmethod
   def set(tracer, value, idx=()):
-    from jax._src.state.primitives import ref_set
+    from jax._src.state.primitives import ref_set  # pytype: disable=import-error
     return ref_set(tracer, idx, value)
 
   def _getitem(self, tracer, idx) -> Array:
     if not isinstance(idx, tuple):
       idx = idx,
-    from jax._src.state.primitives import ref_get
+    from jax._src.state.primitives import ref_get  # pytype: disable=import-error
     return ref_get(tracer, idx)
 
   def _setitem(self, tracer, idx, value) -> None:
     if not isinstance(idx, tuple):
       idx = idx,
-    from jax._src.state.primitives import ref_set
+    from jax._src.state.primitives import ref_set  # pytype: disable=import-error
     return ref_set(tracer, idx, value)
 
   def __repr__(self) -> str:
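The `_getitem`/`_setitem` hooks above are what make `Ref`s indexable inside traced code: `x_ref[...]` lowers to `ref_get` and `x_ref[...] = v` to `ref_set`. A sketch (assuming the `run_state` wrapper from this commit):

import jax.numpy as jnp
from jax._src.state.discharge import run_state

def bump(x_ref):
  x_ref[...] = x_ref[...] + 1.0  # ref_get, then ref_set

print(run_state(bump)(jnp.float32(0.0)))  # 1.0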
jax/_src/state/utils.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for tracing stateful functions."""
+
+from jax.interpreters import partial_eval as pe
+
+from jax._src import core
+from jax._src import linear_util as lu
+from jax._src.state import AbstractRef
+from jax._src.util import (partition_list, merge_lists, split_list, safe_map,
+                           safe_zip)
+from jax._src.state.primitives import ref_get
+
+map, unsafe_map = safe_map, map
+zip, unsafe_zip = safe_zip, zip
+
+def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
+  all_const_avals = [var.aval for var in jaxpr.constvars]
+  is_const_ref = [isinstance(var.aval, AbstractRef) for var in
+                  jaxpr.constvars]
+  const_avals_, const_ref_avals = partition_list(is_const_ref, all_const_avals)
+  const_avals = map(AbstractRef, const_avals_)
+  merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
+  arg_avals = [var.aval for var in jaxpr.invars]
+  in_avals = [*merged_const_avals, *arg_avals]
+  num_consts = len(merged_const_avals)
+
+  def _hoist(*consts_args):
+    all_consts, args = split_list(consts_args, [num_consts])
+    consts, const_refs = partition_list(is_const_ref, all_consts)
+    # We immediately read the const values out of the `Ref`s.
+    consts = map(lambda x: ref_get(x, ()), consts)
+    all_consts = merge_lists(is_const_ref, consts, const_refs)
+    return core.eval_jaxpr(jaxpr, all_consts, *args)
+  hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(_hoist), in_avals)
+  assert not consts, "All consts should have been converted to refs"
+  return hoisted_jaxpr
+
+def val_to_ref_aval(x) -> AbstractRef:
+  aval = core.raise_to_shaped(core.get_aval(x))
+  if type(aval) is not core.ShapedArray:
+    raise Exception(f"can't make ref from {x}")
+  return AbstractRef(aval)
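A sketch of what `hoist_consts_to_refs` accomplishes: a jaxpr that closes over a constant array becomes one that takes that constant as an extra leading `Ref` argument (example not part of the diff):

import jax
import jax.numpy as jnp
from jax._src.state.utils import hoist_consts_to_refs

table = jnp.arange(4.0)

def f(i):
  return table[i]  # `table` is captured as a jaxpr constant

closed = jax.make_jaxpr(f)(0)
assert len(closed.jaxpr.constvars) == 1
hoisted = hoist_consts_to_refs(closed.jaxpr)
assert not hoisted.constvars
print(len(hoisted.invars))  # 2: the hoisted const (as a Ref) plus `i`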
tests/BUILD (13 lines changed)
@@ -998,10 +998,23 @@ jax_test(
 jax_test(
     name = "state_test",
     srcs = ["state_test.py"],
+    # Use fewer cases to prevent timeouts.
+    args = [
+        "--jax_num_generated_cases=5",
+    ],
+    backend_variant_args = {
+        "tpu_pjrt_c_api": ["--jax_num_generated_cases=1"],
+    },
     enable_configs = [
        "gpu",
        "cpu",
    ],
+    shard_count = {
+        "cpu": 2,
+        "gpu": 2,
+        "tpu": 2,
+    },
+    deps = py_deps("hypothesis"),
 )
 
 jax_test(
File diff suppressed because it is too large.