# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for discharging state primitives."""
from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
from functools import partial
import operator
from typing import Any, Protocol

import numpy as np

from jax._src import api_util
from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing
from jax._src.state import indexing
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.state.utils import hoist_consts_to_refs
from jax._src.typing import Array
from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
                           partition_list, merge_lists, split_dict)

## JAX utilities

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
PyTreeDef = tree_util.PyTreeDef

## Discharging state

# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values
# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr
# into a "pure" jaxpr that takes in and outputs values and no longer has the
# `Read/Write/Accum` effects.

def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], *,
                    should_discharge: bool | Sequence[bool] = True
                    ) -> tuple[core.Jaxpr, list[Any]]:
  """Converts a jaxpr that takes in `Ref`s into one that doesn't."""
  if isinstance(should_discharge, bool):
    should_discharge = [should_discharge] * len(jaxpr.invars)
  in_avals = [v.aval.inner_aval
              if isinstance(v.aval, AbstractRef) and d
              else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
  eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
                                    should_discharge, consts))
  new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
  return new_jaxpr, new_consts
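
# For example, a minimal usage sketch (the `body` function and shapes here are
# illustrative, not part of this module): trace a stateful function to a jaxpr,
# then discharge it into a pure jaxpr.
#
#   def body(x_ref):
#     x_ref[...] = x_ref[...] + 1.
#     return []
#   aval = AbstractRef(core.ShapedArray((4,), np.dtype("float32")))
#   stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
#       lu.wrap_init(body), [aval])
#   pure_jaxpr, new_consts = discharge_state(stateful_jaxpr, consts)
#
# `pure_jaxpr` takes the array value in place of the `Ref` and returns the
# final value of that `Ref` appended to the jaxpr's regular outputs.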

@dataclasses.dataclass
class Environment:
  env: dict[core.Var, Any]

  def read(self, v: core.Atom) -> Any:
    if type(v) is core.Literal:
      return v.val
    assert isinstance(v, core.Var)
    return self.env[v]

  def write(self, v: core.Var, val: Any) -> None:
    self.env[v] = val

class DischargeRule(Protocol):

  def __call__(self, in_avals: Sequence[core.AbstractValue],
               out_avals: Sequence[core.AbstractValue], *args: Any,
               **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]:
    ...

_discharge_rules: dict[core.Primitive, DischargeRule] = {}

def register_discharge_rule(prim: core.Primitive):
  def register(f: DischargeRule):
    _discharge_rules[prim] = f
  return register
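
# As an illustration (`my_stateful_p` is a hypothetical primitive, not part of
# this module): a discharge rule receives the input/output avals plus the input
# *values*, and returns `(new_invals, outs)`, where `new_invals` gives the
# updated value for each input (`None` for inputs that aren't discharged):
#
#   @register_discharge_rule(my_stateful_p)
#   def _my_stateful_discharge_rule(in_avals, out_avals, ref_val, x, **params):
#     del in_avals, out_avals
#     return (ref_val + x, None), []  # update the ref input; no regular outputs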

def _eval_jaxpr_discharge_state(
    jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
    *args: Any):
  env = Environment({})

  map(env.write, jaxpr.constvars, consts)
  # Here some args may correspond to `Ref` avals but they'll be treated like
  # regular values in this interpreter.
  map(env.write, jaxpr.invars, args)

  refs_to_discharge = {id(v.aval) for v, d in zip(jaxpr.invars, should_discharge)
                       if d and isinstance(v.aval, AbstractRef)}

  for eqn in jaxpr.eqns:
    if eqn.primitive is core.mutable_array_p:
      [invar], [outvar] = eqn.invars, eqn.outvars
      ans = env.read(invar)
      refs_to_discharge.add(id(outvar.aval))
    elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars)
          or core.internal_mutable_array_effect in eqn.effects):
      if eqn.primitive not in _discharge_rules:
        raise NotImplementedError("No state discharge rule implemented for "
                                  f"primitive: {eqn.primitive}")
      invals = map(env.read, eqn.invars)
      in_avals = [v.aval for v in eqn.invars]
      out_avals = [v.aval for v in eqn.outvars]
      new_invals, ans = _discharge_rules[eqn.primitive](
          in_avals, out_avals, *invals, **eqn.params)
      for new_inval, invar in zip(new_invals, eqn.invars):
        if new_inval is not None:
          env.write(invar, new_inval)  # type: ignore[arg-type]
    else:
      # Default primitive rule, similar to `core.eval_jaxpr`. Note that here
      # we assume any higher-order primitives inside of the jaxpr are *not*
      # stateful.
      subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
      ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars),
                               **bind_params)
    if eqn.primitive.multiple_results:
      map(env.write, eqn.outvars, ans)
    else:
      env.write(eqn.outvars[0], ans)
  # By convention, we return the outputs of the jaxpr first and then the final
  # values of the `Ref`s. Callers to this function should be able to split
  # them up by looking at `len(jaxpr.outvars)`.
  out_vals = map(env.read, jaxpr.outvars)
  ref_vals = map(
      env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge])
  return out_vals + ref_vals
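
# For instance (a sketch): if `jaxpr` has two outvars and one discharged `Ref`
# input, the returned list is `[out0, out1, final_ref_val]`, which a caller
# splits as:
#
#   outs, final_ref_vals = split_list(results, [len(jaxpr.outvars)])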

def _is_trivial_indexer(indexer: indexing.NDIndexer):
  for s, idx in zip(indexer.shape, indexer.indices):
    if not isinstance(idx, indexing.Slice):
      return False
    if not isinstance(idx.start, int):
      return False
    if idx.start:
      return False
    if idx.size != s:
      return False
  return True

def _convert_to_array_indexer(indexer: indexing.NDIndexer
                              ) -> tuple[int | Array, ...]:
  # This is the general gather case. We need to create the gather arrays.
  is_integer_indexer, _, integer_indexer = (
      indexing.unpack_ndindexer(indexer)
  )
  total_shape = indexer.get_indexer_shape()
  int_indexer_shape = indexer.int_indexer_shape
  slice_shape = total_shape[len(int_indexer_shape):]
  slice_dims = tuple(
      i + len(int_indexer_shape) for i in range(len(slice_shape))
  )
  slice_dim_iter = iter(slice_dims)
  slice_indexer: list[Array] = []
  for idx, is_int_index in zip(indexer.indices, is_integer_indexer):
    if not is_int_index:
      assert isinstance(idx, indexing.Slice)
      slice_indices = lax.broadcasted_iota(
          np.dtype("int32"), total_shape, next(slice_dim_iter)
      ) + idx.start
      slice_indexer.append(slice_indices)
      integer_indexer = tuple(
          lax.expand_dims(idx, (-1,)) for idx in integer_indexer
      )
      continue
  assert next(slice_dim_iter, None) is None
  return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer))


def _maybe_convert_to_dynamic_slice(
    indexer: indexing.NDIndexer,
) -> (
    tuple[tuple[Array | int, ...], tuple[Array | int, ...], tuple[int, ...]]
    | None
):
  # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
  # if each of the indexers is a `Slice` or a ()-shaped value.
  if not all(isinstance(i, indexing.Slice) or not np.shape(i)
             for i in indexer.indices):
    return None
  # TODO(b/329733289): support strided load/store in interpret mode.
  for i in indexer.indices:
    if isinstance(i, indexing.Slice) and i.stride > 1:
      raise NotImplementedError("Unimplemented stride support.")
  _convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
  starts = tuple(
      _convert_i32(i.start) if isinstance(i, indexing.Slice)
      else _convert_i32(i) for i in indexer.indices
  )
  sizes = tuple(
      i.size if isinstance(i, indexing.Slice) else 1 for i in indexer.indices
  )
  squeeze_dims = tuple(
      i
      for i, idx in enumerate(indexer.indices)
      if not isinstance(idx, indexing.Slice)
  )
  return starts, sizes, squeeze_dims
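
# For example (a sketch): indexing an (8, 128)-shaped ref with
# `[Slice(2, 4), j]`, where `j` is a ()-shaped integer, converts to
#
#   starts = (2, j), sizes = (4, 1), squeeze_dims = (1,)
#
# i.e. a dynamic slice of shape (4, 1) whose singleton axis 1 the caller then
# squeezes away.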


@register_discharge_rule(get_p)
def _get_discharge_rule(
    in_avals: Sequence[core.AbstractValue],
    out_avals: Sequence[core.AbstractValue], x, *idx,
    tree):
  del in_avals, out_avals
  y = _get_discharge(x, idx, tree)
  return (None,) * (len(idx) + 1), y

def _prepend_gather(x, indexer):
  # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
  return x[None][(np.array(0, 'int32'), *indexer)]

def _prepend_scatter(x, indexer, val, *, add=False):
  # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
  # However, since this is scatter, we need to remove the 1-sized dimension
  # we added at the front.
  if add:
    return x[None].at[(0, *indexer)].add(val)[0]
  return x[None].at[(0, *indexer)].set(val)[0]
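
# A sketch of the dummy-dimension trick above: NumPy-style advanced indexing
# only *prepends* the advanced-index dimensions when the advanced indices are
# not all adjacent. E.g. with `x.shape == (8, 128)` and an integer array
# `idx.shape == (3, 4)` indexing the trailing axis:
#
#   x[:, idx]                     # shape (8, 3, 4): dims stay in place
#   x[None][np.array(0), :, idx]  # shape (3, 4, 8): the dummy leading index
#                                 # forces the prepend rule
#
# which gives gather and scatter a single, consistent output layout.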


def index_array(x, indexers):
  result = x
  for indexer in indexers:
    # Check for a missing indexer before anything else: `_is_trivial_indexer`
    # assumes a real indexer and would fail on `None`.
    if indexer is None:
      continue
    if _is_trivial_indexer(indexer):
      continue
    # If everything in the indexer is a slice or ()-shaped, we can also
    # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
    # We need to squeeze out the 1-sized slices at the end.
    if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
      starts, sizes, squeeze_dims = maybe_slice
      y = lax_slicing.dynamic_slice(result, starts, sizes)
      result = lax.squeeze(y, squeeze_dims)
    else:
      indexer = _convert_to_array_indexer(indexer)
      result = result[None][(np.array(0, "int32"), *indexer)]
  return result

def index_swap_array(x, indexers, val):
  result = x
  result_val = val
  for indexer in indexers:
    if _is_trivial_indexer(indexer):
      continue
    # If everything in the indexer is a slice or ()-shaped, we can also
    # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
    # We need to squeeze out the 1-sized slices at the end.
    if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
      starts, sizes, squeeze_dims = maybe_slice
      result_old = lax_slicing.dynamic_slice(result, starts, sizes)
      result_val = lax.expand_dims(result_val, squeeze_dims)
      y = lax_slicing.dynamic_update_slice(result, result_val, starts)
      result = lax.squeeze(result_old, squeeze_dims)
      result_val = y
    else:
      indexer = _convert_to_array_indexer(indexer)
      result_old = _prepend_gather(result, indexer)
      result_val = _prepend_scatter(result, indexer, result_val)
      result = result_old
  return result, result_val

def _get_discharge(x, idx, tree):
  indexers = tree_util.tree_unflatten(tree, idx)
  return index_array(x, indexers)

def _indexer(idx, indexed_dims):
  idx_ = iter(idx)
  indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
  assert next(idx_, None) is None
  return indexer

@register_discharge_rule(swap_p)
def _swap_discharge_rule(
    in_avals: Sequence[core.AbstractValue],
    out_avals: Sequence[core.AbstractValue], x, val, *idx,
    tree):
  del in_avals, out_avals
  z, x_new = _swap_discharge(x, val, idx, tree)
  return (x_new, None) + (None,) * len(idx), z

def _swap_discharge(x, val, idx, tree):
  indexers = tree_util.tree_unflatten(tree, idx)
  return index_swap_array(x, indexers, val)

@register_discharge_rule(addupdate_p)
def _addupdate_discharge_rule(
    in_avals: Sequence[core.AbstractValue],
    out_avals: Sequence[core.AbstractValue], x, val, *idx,
    tree):
  del in_avals, out_avals
  ans = _addupdate_discharge(x, val, idx, tree)
  return (ans, None) + (None,) * len(idx), []

def _addupdate_discharge(x, val, idx, tree):
  indexers = tree_util.tree_unflatten(tree, idx)
  if len(indexers) > 1:
    raise NotImplementedError("Only single indexer is supported.")
  indexer = indexers[0]
  if _is_trivial_indexer(indexer):
    return x + val
  # If everything in the indexer is a slice or ()-shaped, we can also
  # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
  # We need to squeeze out the 1-sized slices at the end.
  if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
    starts, sizes, squeeze_dims = maybe_slice
    x_old = lax_slicing.dynamic_slice(x, starts, sizes)
    val = lax.expand_dims(val, squeeze_dims)
    y = lax_slicing.dynamic_update_slice(x, x_old + val, starts)
    return y
  indexer = _convert_to_array_indexer(indexer)
  return _prepend_scatter(x, indexer, val, add=True)
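
# In array-update terms, this discharge is a sketch of `x.at[idx].add(val)`:
# e.g. for a ref of shape (8,) and the single indexer `Slice(2, 3)`, the result
# equals `x.at[2:5].add(val)` with `val.shape == (3,)`.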

@weakref_lru_cache
def _cached_closed_jaxpr_discharge(closed_jaxpr):
  jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
  num_outs = len(jaxpr.outvars)
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
  discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
  fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr))
  return discharged_closed_jaxpr, num_outs, fun

@register_discharge_rule(core.closed_call_p)
def _closed_call_discharge_rule(
    in_avals: Sequence[core.AbstractValue], _, *args,
    call_jaxpr: core.ClosedJaxpr):
  discharged_closed_jaxpr, num_outs, fun = _cached_closed_jaxpr_discharge(call_jaxpr)
  out_and_ref_vals = core.closed_call_p.bind(fun, *args,
                                             call_jaxpr=discharged_closed_jaxpr)
  out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
  ref_vals_iter = iter(ref_vals)
  new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef)
                     else None for aval in in_avals)
  sentinel = object()
  assert next(ref_vals_iter, sentinel) is sentinel
  return new_invals, out_vals

# `run_state`

run_state_p = core.Primitive("run_state")
run_state_p.multiple_results = True

def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
                    which_linear: tuple[bool, ...]):
  if config.enable_checks.value:
    core.check_jaxpr(jaxpr)
    assert len(jaxpr.invars) == len(args)
    assert len(which_linear) == len(args)
  return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr,
                             which_linear=which_linear)
run_state_p.def_custom_bind(_run_state_bind)

def _run_state_impl(*args: Any, jaxpr: core.Jaxpr,
                    which_linear: tuple[bool, ...]):
  del which_linear
  discharged_jaxpr, consts = discharge_state(jaxpr, ())
  return core.eval_jaxpr(discharged_jaxpr, consts, *args)
run_state_p.def_impl(_run_state_impl)
mlir.register_lowering(run_state_p, mlir.lower_fun(_run_state_impl))

def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
                             which_linear: tuple[bool, ...]):
  del which_linear
  # When we abstractly evaluate `run_state`, we want to keep track of which
  # input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to
  # "propagate" out its inner effects. Otherwise, the effects are local to this
  # `run_state`.
  is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)}
  nonlocal_effects = {e for e in jaxpr.effects
                      if (isinstance(e, RefEffect) and e.input_index in is_ref)
                      or not isinstance(e, RefEffect)}
  return avals, nonlocal_effects
run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)

def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *,
                   jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  for _ in range(len(nonzero_tangents)):
    _, out_nonzero_tangents = ad.jvp_jaxpr(
        core.ClosedJaxpr(discharged_jaxpr, body_consts),
        nonzero_tangents, instantiate=nonzero_tangents)
    if out_nonzero_tangents == nonzero_tangents:
      break
    nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents)
  else:
    raise Exception("Invalid fixpoint")
  del discharged_jaxpr, body_consts, out_nonzero_tangents
  tangents = [ad.instantiate_zeros(t) if inst else t
              for t, inst in zip(tangents, nonzero_tangents)]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  closed_jvp_jaxpr, _ = ad.jvp_jaxpr(pe.close_jaxpr(jaxpr),
                                     nonzero_tangents, [])
  jvp_jaxpr_, jvp_consts = closed_jvp_jaxpr.jaxpr, closed_jvp_jaxpr.consts
  jvp_jaxpr = hoist_consts_to_refs(jvp_jaxpr_)
  jvp_which_linear = (*(False,) * len(jvp_consts), *which_linear,
                      *(True,) * len(tangents))
  out = run_state_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr,
                         which_linear=jvp_which_linear)
  out_consts, out_primals, out_tangents = split_list(out, [len(jvp_consts),
                                                           len(primals)])
  del out_consts
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, nonzero_tangents)]
  return out_primals, out_tangents
ad.primitive_jvps[run_state_p] = _run_state_jvp

_save_everything = lambda *_, **__: True

def _convert_outputs_to_writes(
    jaxpr: core.Jaxpr) -> tuple[core.Jaxpr, list[core.ShapedArray]]:
  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."

  in_avals = [v.aval for v in jaxpr.invars]
  @lu.wrap_init
  def eval_jaxpr(*refs):
    # We split the refs into the original input refs and the dummy residual
    # refs.
    orig_refs, residual_refs = split_list(refs, [len(in_avals)])
    residual_vals = core.eval_jaxpr(jaxpr, (), *orig_refs)
    for res_ref, res_val in zip(residual_refs, residual_vals):
      res_ref[...] = res_val
    return []
  res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
                   else v.aval for v in jaxpr.outvars]
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [*in_avals, *res_ref_avals])
  assert not consts
  return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype)  # pytype: disable=attribute-error
                 for a in res_ref_avals]

def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
  assert not jaxpr.constvars, "Jaxpr should not have constvars"

  @lu.wrap_init
  def eval_jaxpr(*refs):
    residual_refs, orig_refs = split_list(refs, [num_res])
    residual_vals = [r[...] for r in residual_refs]
    () = core.eval_jaxpr(jaxpr, (), *residual_vals, *orig_refs)
    return []

  res_val_avals, orig_ref_avals = \
      split_list([v.aval for v in jaxpr.invars], [num_res])
  res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
                   aval for aval in res_val_avals]
  jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
  return jaxpr

def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
                            jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]):
  num_inputs = len(tracers)
  assert num_inputs == len(jaxpr.invars)
  in_unknowns = [not t.pval.is_known() for t in tracers]
  # We first need to run a fixpoint to determine which of the `Ref`s are unknown
  # after running the for loop. We want to use the jaxpr to determine which
  # `Ref`s are unknown after executing the for loop body given which `Ref`s are
  # unknown before. However, the jaxpr has no outputs. Instead, we discharge
  # the body and run the fixpoint with the discharged jaxpr. We can do this
  # because the outputs of the jaxpr are one-to-one with the inputs.
  discharged_jaxpr_, discharged_consts = discharge_state(jaxpr, ())
  discharged_jaxpr = pe.convert_constvars_jaxpr(discharged_jaxpr_)
  for _ in range(num_inputs):
    jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
    _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
        discharged_jaxpr, jaxpr_in_unknowns, jaxpr_in_unknowns,
        in_unknowns, False, _save_everything)
    # assert out_inst == out_unknowns
    out_unknowns = list(out_unknowns)
    if out_unknowns == in_unknowns:
      break
    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
  else:
    raise Exception("Invalid fixpoint")
  del out_unknowns  # redundant since it's the same as `in_unknowns`
  tracers = tuple(trace.instantiate_const(t) if uk else t
                  for t, uk in zip(tracers, in_unknowns))

  # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
  # primitives like `get`/`set`.
  jaxpr_known_resout, jaxpr_unknown_resin_, _, _, num_res_out, num_res_ref = \
      pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns, in_inst=in_unknowns,
                                     ensure_out_unknowns=[], ensure_out_inst=[],
                                     saveable=_save_everything)
  # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref`
  # and regular valued input/outputs. However, we'd like to bind these jaxprs to
  # a `for`, which expects only `Ref` inputs and no output. We need to convert
  # both of these jaxprs into ones that are compatible with `for`.

  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
  # to output residual values (none of them should be `Ref`s). We'll need to
  # convert the output residual values into `Ref`s that are initially empty
  # `Ref`s that are written to at the end of the jaxpr.
  num_res = num_res_out + num_res_ref

  num_invars = len(jaxpr_known_resout.invars) - num_res_ref
  _, res_ref_avals = split_list(
      [v.aval for v in jaxpr_known_resout.invars], [num_invars])
  res_avals = [a.inner_aval for a in res_ref_avals]  # pytype: disable=attribute-error
  jaxpr_known, new_res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
  # We now run the known jaxpr to obtain our residual values.
  known_tracers, _ = partition_list(in_unknowns, tracers)
  known_which_linear, _ = partition_list(in_unknowns, which_linear)
  known_vals = [t.pval.get_known() for t in known_tracers]
  all_res_avals = [*res_avals, *new_res_avals]
  empty_res = map(ad_util.zeros_like_aval, all_res_avals)
  jaxpr_known_args = [*known_vals, *empty_res]

  jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
  out_flat = run_state_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known,
                              which_linear=jaxpr_known_which_linear)
  known_outputs, residuals = split_list(out_flat, [len(known_tracers)])
  residuals = map(trace.new_instantiated_const, residuals)
  ref_res, nonref_res = split_list(residuals, [num_res_ref])

  # Now we handle the `jaxpr_unknown` that expects residual values as inputs.
  # This jaxpr is the output of `partial_eval_jaxpr_stateful` that marks which
  # inputs are actually used.
  # `partial_eval_jaxpr_stateful` doesn't remove extra inputs/outputs for you
  # so we use `dce_jaxpr` here to do that.
  # To make it compatible with `for`, we need to convert those residual values
  # into `Ref`s.
  jaxpr_unknown = _convert_inputs_to_reads(len(new_res_avals),
                                           jaxpr_unknown_resin_)
  _, unknown_tracers = partition_list(in_unknowns, tracers)
  _, uk_which_linear = partition_list(in_unknowns, which_linear)
  unknown_which_linear = (False,) * num_res + tuple(uk_which_linear)
  unknown_inputs = [*nonref_res, *ref_res, *unknown_tracers]
  # Outputs match inputs so we construct output tracers that look like the input
  # tracers.
  res_ref_unknown_outputs = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None)
      for t in unknown_inputs]
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)

  assert len(unknown_inputs) == len(res_ref_unknown_outputs)
  assert len(unknown_inputs) == len(jaxpr_unknown.invars)
  uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear)
  _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
                                             **uk_params)
  eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
                          run_state_p, uk_params,
                          eqn_effects, source)
  for t in res_ref_unknown_outputs: t.recipe = eqn
  _, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res])
  return merge_lists(in_unknowns, known_outputs, unknown_outputs)
pe.custom_partial_eval_rules[run_state_p] = _run_state_partial_eval

def _run_state_partial_eval_custom(
    saveable: Callable[..., pe.RematCases_],
    in_unknowns: Sequence[bool],
    in_inst: Sequence[bool],
    eqn: core.JaxprEqn):
  if not any(in_unknowns):
    return eqn, None, in_unknowns, [False] * len(in_unknowns), []
  jaxpr, which_linear = split_dict(eqn.params, ["jaxpr", "which_linear"])
  num_inputs = len(eqn.invars)
  # We first need to run a fixpoint to determine which of the `Ref`s are unknown
  # after running the for loop. However, the jaxpr has no outputs. Instead, we
  # discharge the body and run the fixpoint with the discharged jaxpr. We can do
  # this because the outputs of the discharged jaxpr are one-to-one with the
  # inputs.
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
  discharged_jaxpr = discharged_jaxpr.replace(
      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
      constvars=[])
  in_unknowns, in_inst = list(in_unknowns), list(in_inst)
  out_unknowns, out_inst = in_unknowns, in_unknowns
  for _ in range(num_inputs):
    jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns
    _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful(
        discharged_jaxpr,
        in_unknowns=jaxpr_in_unknowns,
        in_inst=jaxpr_in_unknowns,
        ensure_out_unknowns=in_unknowns,
        ensure_out_inst=in_unknowns,
        saveable=saveable)
    out_unknowns = list(out_unknowns)
    if out_unknowns == in_unknowns:
      break
    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
  else:
    if num_inputs > 0: raise Exception("Invalid fixpoint")
  del out_unknowns  # Redundant since it's the same as `in_unknowns`
  new_inst = [x for x, already, inst in zip(eqn.invars, in_inst, out_inst)
              if type(x) is core.Var and inst and not already]

  # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful
  # primitives like `get`/`set`.
  jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res_out, num_res_ref = \
      pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns,
                                     in_unknowns, [], [], saveable)
  num_res = num_res_ref + num_res_out
  # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref` and
  # non-Ref input/outputs. However, we'd like to bind these jaxprs to a
  # `for`, which expects only `Ref` inputs and no output. We need to convert
  # both of these jaxprs into ones that are compatible with `for`.
  # TODO(sharadmv,mattjj): implement "passthrough" optimization.

  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
  # to output residual values (none of them should be `Ref`s). We'll need to
  # convert the output residual values into `Ref`s that are initially empty
  # `Ref`s that are written to at the end of the jaxpr.
  jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout)

  # In a stateful partial_eval, the residuals should be `Ref`s.
  res_avals = map(AbstractRef, res_avals)  # type: ignore

  known_invars, staged_invars = partition_list(in_unknowns, eqn.invars)
  known_outvars, staged_outvars = partition_list(in_unknowns, eqn.outvars)
  newvar = core.gensym()
  _, res_ref_avals = split_list([v.aval for v in jaxpr_known_resout.invars],
                                [len(known_invars)])
  nonref_resvars = map(newvar, res_avals)
  ref_resvars = map(newvar, res_ref_avals)
  known_out_resvars = map(newvar, [*res_ref_avals, *res_avals])

  known_which_linear, _ = partition_list(in_unknowns, which_linear)
  jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res)
  known_and_res_invars = [*known_invars, *ref_resvars, *nonref_resvars]

  known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear)
  _, known_effects = run_state_p.abstract_eval(
      *[v.aval for v in known_and_res_invars], **known_params)
  eqn_known = pe.new_jaxpr_eqn(known_and_res_invars,
                               [*known_outvars, *known_out_resvars],
                               run_state_p, known_params,
                               known_effects, eqn.source_info)

  jaxpr_staged = _convert_inputs_to_reads(len(res_avals), jaxpr_staged_resin_)

  _, staged_which_linear = partition_list(in_unknowns, which_linear)
  which_linear_unknown = (*[False] * num_res, *staged_which_linear)
  staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown)
  rejiggered_resvars = [*nonref_resvars, *ref_resvars]
  _, staged_invars = partition_list(in_unknowns, eqn.invars)
  res_staged_invars = [*rejiggered_resvars, *staged_invars]
  _, staged_effects = run_state_p.abstract_eval(
      *[v.aval for v in res_staged_invars], **staged_params)
  _, staged_outvars = partition_list(in_unknowns, eqn.outvars)
  if num_res:
    @lu.wrap_init
    def staged(*args):
      out = run_state_p.bind(*args, **staged_params)
      return out[num_res:]
    staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
        staged, [v.aval for v in res_staged_invars])
    eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
                                  staged_outvars,
                                  core.closed_call_p,
                                  dict(call_jaxpr=pe.close_jaxpr(staged_call_jaxpr)),
                                  staged_effects, eqn.source_info)
    assert len(res_staged_invars) == len(staged_call_jaxpr.invars)
    assert len(staged_outvars) == len(staged_call_jaxpr.outvars)
  else:
    eqn_staged = pe.new_jaxpr_eqn(staged_invars,
                                  staged_outvars,
                                  run_state_p,
                                  staged_params,
                                  staged_effects, eqn.source_info)
  new_vars = [*new_inst, *nonref_resvars, *ref_resvars]
  return eqn_known, eqn_staged, in_unknowns, in_unknowns, new_vars
pe.partial_eval_jaxpr_custom_rules[run_state_p] = _run_state_partial_eval_custom

def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
                     ) -> tuple[core.Jaxpr, Any]:
  def trans(*args):
    # First we want to run the computation to read all the residual refs. We
    # can do that by using partial evaluation with all linear inputs unknown.
    res_jaxpr_, tangent_jaxpr_, *_, num_res_out, num_res_ref = \
        pe.partial_eval_jaxpr_stateful(jaxpr, which_linear,
                                       in_inst=which_linear,
                                       ensure_out_inst=[],
                                       ensure_out_unknowns=[],
                                       saveable=_save_everything)

    num_unknown = sum(which_linear)
    num_known = len(jaxpr.invars) - num_unknown
    res_args, _ = partition_list(which_linear, args)
    res_jaxpr_avals = [v.aval for v in res_jaxpr_.invars]
    _, res_avals = split_list(res_jaxpr_avals, [num_known])
    res_avals = [a.inner_aval for a in res_avals]  # pytype: disable=attribute-error
    all_avals = [*res_avals, *[v.aval for v in res_jaxpr_.outvars]]
    empty_res = map(ad.zeros_like_aval, all_avals)
    res_jaxpr, _ = _convert_outputs_to_writes(res_jaxpr_)
    res = run_state_p.bind(*res_args, *empty_res, jaxpr=res_jaxpr,
                           which_linear=(False,) * (len(res_args)
                                                    + len(empty_res)))
    res = res[len(res_args):]
    ref_res_, nonref_res_ = split_list(res, [num_res_ref])

    # Now that we have residual values, we run the tangent jaxpr. It takes as
    # input the residuals, the loop index, and all the refs (at least, the ones
    # that are used in the body). Luckily, `tangent_jaxpr_` has all known and
    # unknown inputs!
    tangent_jaxpr, used_inputs = pe.dce_jaxpr(tangent_jaxpr_, [])
    used_res, used_cts = split_list(used_inputs, [len(res)])
    used_nonref_res, used_ref_res = split_list(used_res, [num_res_out])
    _, nonref_res = partition_list(used_nonref_res, nonref_res_)
    _, ref_res = partition_list(used_ref_res, ref_res_)
    primals_args = [*nonref_res, *ref_res]
    _, tangent_args = partition_list(which_linear, args)
    _, ct_args = partition_list(used_cts, tangent_args)
    ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ())
    return []
  jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
  return jaxpr_trans, consts

def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,
                         which_linear: tuple[bool, ...]):
  # if any in_ct is nonzero, we definitely want it in args_ (and the
  # corresponding x in args could be an undefined primal, but doesn't have
  # to be)
  # for non-res stuff:
  #   getting and setting => (nonzero ct, UndefinedPrimal arg)
  #   just setting        => (nonzero ct, not UndefinedPrimal, dummy value)
  #   just getting        => (zero ct   , UndefinedPrimal arg)
  # for res stuff:
  #                          (zero ct   , not UndefinedPrimal)
  assert any(which_linear)
  transpose_args = []
  for x, ct in zip(args, in_cts):
    if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x):
      # this is a residual, take x!
      transpose_args.append(x)
    elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x):
      # the loop was 'just getting', plug in a zero
      transpose_args.append(ad_util.zeros_like_aval(x.aval))
    elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x):
      # the loop was 'just setting', grab that cotangent! x is dummy
      transpose_args.append(ct)
    elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x):
      # the loop was 'getting and setting', grab that cotangent!
      transpose_args.append(ct)
  jaxpr_transpose_, consts = _transpose_jaxpr(jaxpr, which_linear)
  jaxpr_transpose = hoist_consts_to_refs(jaxpr_transpose_)
  which_linear = (*[False] * len(consts), *which_linear)
  const_all_outs = run_state_p.bind(*consts, *transpose_args,
                                    jaxpr=jaxpr_transpose,
                                    which_linear=which_linear)
  _, all_outs = split_list(const_all_outs, [len(consts)])
  ct_outs = [ct if ad.is_undefined_primal(x) else None
             for x, ct in zip(args, all_outs)]
  return ct_outs
ad.primitive_transposes[run_state_p] = _run_state_transpose

@register_discharge_rule(run_state_p)
def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
                              out_avals: Sequence[core.AbstractValue],
                              *args: Any, jaxpr: core.Jaxpr,
                              which_linear: Sequence[bool]):
  del out_avals
  out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear)
  new_invals = []
  for aval, out_val in zip(in_avals, out_vals):
    new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
  return new_invals, out_vals

def initial_style_jaxpr(
    fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
  return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))

@weakref_lru_cache
def _initial_style_jaxpr(fun, in_tree, in_avals):
  fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun), tree_util.treedef_tuple((in_tree,)))
  debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
  return jaxpr, consts, out_tree_thunk()

def run_state(f: Callable[..., None]):
  def wrapped(args):
    flat_args, in_tree = tree_util.tree_flatten(args)
    avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args]
    jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals))
    jaxpr = hoist_consts_to_refs(jaxpr_)
    which_linear = (False,) * (len(consts) + len(flat_args))
    out_const_flat = run_state_p.bind(*consts, *flat_args, jaxpr=jaxpr,
                                      which_linear=which_linear)
    _, out_flat = split_list(out_const_flat, [len(consts)])
    return in_tree.unflatten(out_flat)
  return wrapped
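
# A minimal usage sketch (illustrative values; assumes `jnp` is `jax.numpy`):
# `run_state` turns a `Ref`-mutating function into a pure function from initial
# values to final values, preserving the input pytree structure.
#
#   def body(refs):
#     x_ref, y_ref = refs
#     y_ref[...] = x_ref[...] * 2.
#   x_final, y_final = run_state(body)((jnp.ones(3), jnp.zeros(3)))
#   # y_final == 2 * x_final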

def run_state_reference(f: Callable[..., None]):
  def wrapped(args):
    flat_args, in_tree = tree_util.tree_flatten(args)
    avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args]
    jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals))
    jaxpr = hoist_consts_to_refs(jaxpr_)
    discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
    # Evaluate on the flattened leaves, matching `run_state` above.
    out_const_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
                                     *consts, *flat_args)
    _, out_flat = split_list(out_const_flat, [len(consts)])
    return in_tree.unflatten(out_flat)
  return wrapped