mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #11701 from sharadmv:state
PiperOrigin-RevId: 464658336
This commit is contained in:
commit
01819257f6
@ -15,7 +15,7 @@
|
||||
from functools import partial
|
||||
import operator
|
||||
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar
|
||||
from typing import Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
@ -29,8 +29,8 @@ from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
|
||||
treedef_tuple, tree_map, tree_leaves, PyTreeDef)
|
||||
from jax._src import ad_util
|
||||
from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
|
||||
split_list)
|
||||
import jax.numpy as jnp
|
||||
@ -49,350 +49,13 @@ T = TypeVar('T')
|
||||
class Ref(Generic[T]): pass
|
||||
Array = Any
|
||||
|
||||
## State effect
|
||||
StateEffect = state.StateEffect
|
||||
ShapedArrayRef = state.ShapedArrayRef
|
||||
ref_set = state.ref_set
|
||||
ref_get = state.ref_get
|
||||
ref_addupdate = state.ref_addupdate
|
||||
discharge_state = state.discharge_state
|
||||
|
||||
class StateEffect: pass
|
||||
State = StateEffect()
|
||||
|
||||
## get/swap/addupdate implementations
|
||||
|
||||
# `get` reads a value from a `Ref` type, a.k.a.:
|
||||
# a = get_p.bind(x)
|
||||
# or we can read using indices:
|
||||
# a = get_p.bind(x, 0, 1)
|
||||
# Staging out `a = get_p.bind(x)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# a:f32[3] <- x[]
|
||||
get_p = core.Primitive("get")
|
||||
|
||||
def _get_impl(ref: Ref, *idx: int):
|
||||
del ref, idx
|
||||
raise ValueError("Can't evaluate `get` outside a stateful context.")
|
||||
get_p.def_impl(_get_impl)
|
||||
|
||||
def ref_get(ref: Ref, idx: Tuple[int]) -> Array:
|
||||
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
|
||||
idx = map(jnp.int32, idx)
|
||||
return get_p.bind(ref, *idx)
|
||||
|
||||
# `swap` mutates a `Ref`, setting its value and returns its previous value.
|
||||
# b = swap_p.bind(x, a)
|
||||
# It generalizes the setting operation for a `Ref` as we can ignore the return
|
||||
# value:
|
||||
# _ = swap_p.bind(x, a)
|
||||
# `swap_p` also takes in index arguments following the value, i.e.:
|
||||
# _ = swap_p.bind(x, a, 0, 1)
|
||||
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is
|
||||
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# b:f32[3], x:Ref{f32[3]} <- x, a
|
||||
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is
|
||||
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
|
||||
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
|
||||
# x:Ref{f32[3]}[i, j] <- a
|
||||
swap_p = core.Primitive("swap")
|
||||
|
||||
def _swap_impl(ref: Ref, value: Array, *idx: int):
|
||||
del ref, idx, value
|
||||
raise ValueError("Can't evaluate `swap` outside a stateful context.")
|
||||
swap_p.def_impl(_swap_impl)
|
||||
|
||||
def ref_swap(ref: Ref, idx: Tuple[int], value: Array) -> Array:
|
||||
"""Sets a `Ref`'s value and returns the original value."""
|
||||
idx = map(jnp.int32, idx)
|
||||
return swap_p.bind(ref, value, *idx)
|
||||
|
||||
def ref_set(ref: Ref, idx: Tuple[int], value: Array) -> None:
|
||||
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
|
||||
ref_swap(ref, idx, value)
|
||||
|
||||
|
||||
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
|
||||
# Semantically,
|
||||
# ```
|
||||
# addupdate ref a *idx
|
||||
# ```
|
||||
# is equivalent to
|
||||
# ```
|
||||
# b = get ref *idx
|
||||
# c = add b x
|
||||
# _ = swap ref c *idx
|
||||
# ```
|
||||
addupdate_p = core.Primitive('addupdate')
|
||||
addupdate_p.multiple_results = True
|
||||
|
||||
def _addupdate_impl(ref: Ref, value: Array, *idx: int):
|
||||
del ref, idx, value
|
||||
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
|
||||
addupdate_p.def_impl(_addupdate_impl)
|
||||
|
||||
def ref_addupdate(ref: Ref, idx: Tuple[int], x: Array) -> None:
|
||||
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
|
||||
return addupdate_p.bind(ref, x, *idx)
|
||||
|
||||
## get/set/addupdate abstract evaluation rules
|
||||
|
||||
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
|
||||
# A `ShapedArrayRef` is a abstract value for mutable containers of array types
|
||||
class ShapedArrayRef(core.AbstractValue):
|
||||
__slots__ = ["shape", "dtype"]
|
||||
|
||||
def __init__(self, shape, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
def join(self, other):
|
||||
assert core.symbolic_equal_shape(self.shape, other.shape)
|
||||
assert self.dtype == other.dtype
|
||||
return self
|
||||
|
||||
def _getitem(self, tracer, idx) -> Array:
|
||||
if not isinstance(idx, tuple):
|
||||
idx = idx,
|
||||
return ref_get(tracer, idx)
|
||||
|
||||
def _setitem(self, tracer, idx, val) -> None:
|
||||
if not isinstance(idx, tuple):
|
||||
idx = idx,
|
||||
return ref_set(tracer, idx, val)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
a = core.ShapedArray(self.shape, self.dtype)
|
||||
return f'Ref{{{a.str_short()}}}'
|
||||
|
||||
def at_least_vspace(self):
|
||||
return self
|
||||
|
||||
core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval
|
||||
|
||||
def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx: int):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
|
||||
return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State}
|
||||
get_p.def_effectful_abstract_eval(_get_abstract_eval)
|
||||
|
||||
|
||||
def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
|
||||
*idx: int):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
expected_output_shape = ref_aval.shape[len(idx):]
|
||||
if expected_output_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `swap`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
if ref_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Invalid dtype for `swap`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State}
|
||||
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||||
|
||||
|
||||
def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
|
||||
val_aval: core.AbstractValue,
|
||||
*idx: int):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
expected_output_shape = ref_aval.shape[len(idx):]
|
||||
if expected_output_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `swap`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
return [], {State}
|
||||
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
||||
|
||||
## Pretty printing for `get` and `swap` in jaxprs
|
||||
|
||||
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
|
||||
foreground=pp.Color.GREEN)
|
||||
|
||||
def _get_pp_rule(eqn, context, settings):
|
||||
# Pretty prints `a = get x i` as `a <- x[i]`
|
||||
y, = eqn.outvars
|
||||
x, *idx = eqn.invars
|
||||
idx = ','.join(core.pp_var(i, context) for i in idx)
|
||||
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||
return [lhs, pp.text(' <- '), pp_ref(pp.concat([
|
||||
pp.text(core.pp_var(x, context)), pp.text('['), pp.text(idx), pp.text(']')
|
||||
]))]
|
||||
core.pp_eqn_rules[get_p] = _get_pp_rule
|
||||
|
||||
def _swap_pp_rule(eqn, context, settings):
|
||||
y, = eqn.outvars
|
||||
x, v, *idx = eqn.invars
|
||||
idx = ','.join(core.pp_var(i, context) for i in idx)
|
||||
if type(y) is core.DropVar:
|
||||
# In the case of a set (ignored return value),
|
||||
# pretty print `_ = swap x v i` as `x[i] <- v`
|
||||
del y
|
||||
return [
|
||||
pp_ref(pp.concat([
|
||||
pp.text(core.pp_var(x, context)),
|
||||
pp.text('['), pp.text(idx), pp.text(']')
|
||||
])), pp.text(' <- '), pp.text(core.pp_var(v, context))]
|
||||
else:
|
||||
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
|
||||
x_i = pp.concat([pp.text(core.pp_var(x, context)),
|
||||
pp.text('['), pp.text(idx), pp.text(']')])
|
||||
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||
return [y, pp.text(', '), x_i, pp.text(' <- '),
|
||||
x_i, pp.text(', '), pp.text(core.pp_var(v, context))]
|
||||
core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
||||
|
||||
def _addupdate_pp_rule(eqn, context, settings):
|
||||
# pretty-print ` = addupdate x i v` as `x[i] += v`
|
||||
() = eqn.outvars
|
||||
x, v, *idx = eqn.invars
|
||||
idx = ','.join(core.pp_var(i, context) for i in idx)
|
||||
return [
|
||||
pp_ref(pp.concat([
|
||||
pp.text(core.pp_var(x, context)),
|
||||
pp.text('['), pp.text(idx), pp.text(']')
|
||||
])), pp.text(' += '), pp.text(core.pp_var(v, context))]
|
||||
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
|
||||
|
||||
## get/swap/addupdate JVP rules
|
||||
|
||||
def _get_jvp(primals: List[Any], tangents: List[Any]):
|
||||
ref_primal, *idx = primals
|
||||
assert isinstance(ref_primal.aval, ShapedArrayRef)
|
||||
ref_tangent, *_ = tangents
|
||||
assert isinstance(ref_tangent.aval, ShapedArrayRef)
|
||||
return ref_get(ref_primal, idx), ref_get(ref_tangent, idx) # type: ignore[arg-type]
|
||||
ad.primitive_jvps[get_p] = _get_jvp
|
||||
|
||||
def _swap_jvp(primals: List[Any], tangents: List[Any]):
|
||||
ref_primal, x_primal, *idx = primals
|
||||
assert isinstance(ref_primal.aval, ShapedArrayRef)
|
||||
ref_tangent, x_tangent, *_ = tangents
|
||||
assert isinstance(ref_tangent.aval, ShapedArrayRef)
|
||||
x_tangent = ad_util.instantiate(x_tangent)
|
||||
return (ref_swap(ref_primal, idx, x_primal), # type: ignore[arg-type]
|
||||
ref_swap(ref_tangent, idx, x_tangent)) # type: ignore[arg-type]
|
||||
ad.primitive_jvps[swap_p] = _swap_jvp
|
||||
|
||||
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any]):
|
||||
ref_primal, x_primal, *idx = primals
|
||||
ref_tangent, x_tangent, *_ = tangents
|
||||
x_tangent = ad_util.instantiate(x_tangent)
|
||||
addupdate_p.bind(ref_primal, x_primal, *idx)
|
||||
addupdate_p.bind(ref_tangent, x_tangent, *idx)
|
||||
return [], []
|
||||
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule
|
||||
|
||||
## get/swap/addupdate transpose rules
|
||||
|
||||
def _get_transpose(g, ref, *idx):
|
||||
# get transpose is addupdate
|
||||
if type(g) is not ad_util.Zero:
|
||||
ref_addupdate(ref, idx, g)
|
||||
return [None] + [None] * len(idx)
|
||||
ad.primitive_transposes[get_p] = _get_transpose
|
||||
|
||||
def _swap_transpose(g, ref, x, *idx):
|
||||
# swap transpose is swap
|
||||
x_bar = ref_swap(ref, idx, ad_util.instantiate(g))
|
||||
return [None, x_bar] + [None] * len(idx)
|
||||
ad.primitive_transposes[swap_p] = _swap_transpose
|
||||
|
||||
|
||||
## Discharging state
|
||||
|
||||
# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values
|
||||
# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr
|
||||
# into a "pure" jaxpr that takes in and outputs values and no longer has the
|
||||
# `State` effect.
|
||||
|
||||
def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]:
|
||||
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
|
||||
in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
|
||||
if type(v.aval) is ShapedArrayRef
|
||||
else v.aval for v in jaxpr.invars]
|
||||
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, consts))
|
||||
new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
|
||||
return new_jaxpr, new_consts
|
||||
|
||||
def _dynamic_index(x, idx):
|
||||
if not idx: return x
|
||||
ndim = len(x.shape)
|
||||
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
|
||||
sizes = (1,) * len(idx) + x.shape[len(idx):]
|
||||
out = lax.dynamic_slice(x, starts, sizes)
|
||||
return out.reshape(x.shape[len(idx):])
|
||||
|
||||
def _dynamic_update_index(x, idx, val):
|
||||
if not idx: return val
|
||||
ndim = len(x.shape)
|
||||
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
|
||||
update = val.reshape((1,) * len(idx) + x.shape[len(idx):])
|
||||
return lax.dynamic_update_slice(x, update, starts)
|
||||
|
||||
def _eval_jaxpr_discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any],
|
||||
*args: Any):
|
||||
env: Dict[core.Var, Any] = {}
|
||||
|
||||
def read(v: core.Atom) -> Any:
|
||||
if type(v) is core.Literal:
|
||||
return v.val
|
||||
assert isinstance(v, core.Var)
|
||||
return env[v]
|
||||
|
||||
def write(v: core.Var, val: Any) -> None:
|
||||
env[v] = val
|
||||
|
||||
map(write, jaxpr.constvars, consts)
|
||||
# Here some args may correspond to `Ref` avals but they'll be treated like
|
||||
# regular values in this interpreter.
|
||||
map(write, jaxpr.invars, args)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
in_vals = map(read, eqn.invars)
|
||||
if eqn.primitive is get_p:
|
||||
# `y <- x[i]` becomes `y = ds x i`
|
||||
x, *idx = in_vals
|
||||
write(eqn.outvars[0], _dynamic_index(x, idx))
|
||||
elif eqn.primitive is swap_p:
|
||||
# `z, x[i] <- x[i], val` becomes:
|
||||
# z = ds x i
|
||||
# x = dus x i val
|
||||
x, val, *idx = in_vals
|
||||
write(eqn.outvars[0], _dynamic_index(x, idx))
|
||||
assert isinstance(eqn.invars[0], core.Var)
|
||||
write(eqn.invars[0], _dynamic_update_index(x, idx, val))
|
||||
elif eqn.primitive is addupdate_p:
|
||||
# `x[i] += val` becomes:
|
||||
# y = ds x i
|
||||
# z = y + val
|
||||
# x = dus x i z
|
||||
x, val, *idx = in_vals
|
||||
ans = _dynamic_update_index(x, idx, val + _dynamic_index(x, idx))
|
||||
assert isinstance(eqn.invars[0], core.Var)
|
||||
write(eqn.invars[0], ans)
|
||||
else:
|
||||
# Default primitive rule, similar to `core.eval_jaxpr`. Note that here
|
||||
# we assume any higher-order primitives inside of the jaxpr are *not*
|
||||
# stateful.
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write, eqn.outvars, ans)
|
||||
else:
|
||||
write(eqn.outvars[0], ans)
|
||||
# By convention, we return the outputs of the jaxpr first and then the final
|
||||
# values of the `Ref`s. Callers to this function should be able to split
|
||||
# them up by looking at `len(jaxpr.outvars)`.
|
||||
out_vals = map(read, jaxpr.outvars)
|
||||
ref_vals = map(
|
||||
read, [v for v in jaxpr.invars if type(v.aval) is ShapedArrayRef])
|
||||
return out_vals + ref_vals
|
||||
|
||||
## `for_loop` implementation
|
||||
|
||||
|
19
jax/_src/state/__init__.py
Normal file
19
jax/_src/state/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for state."""
|
||||
from jax._src.state.types import ShapedArrayRef, StateEffect
|
||||
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
|
||||
ref_addupdate, get_p, swap_p,
|
||||
addupdate_p)
|
||||
from jax._src.state.discharge import discharge_state
|
121
jax/_src/state/discharge.py
Normal file
121
jax/_src/state/discharge.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for discharging state primitives."""
|
||||
from functools import partial
|
||||
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
from jax._src.state.types import ShapedArrayRef
|
||||
from jax._src.state.primitives import get_p, swap_p, addupdate_p
|
||||
|
||||
## JAX utilities
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
## Discharging state
|
||||
|
||||
# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values
|
||||
# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr
|
||||
# into a "pure" jaxpr that takes in and outputs values and no longer has the
|
||||
# `StateEffect` effect.
|
||||
|
||||
def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]:
|
||||
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
|
||||
in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
|
||||
if type(v.aval) is ShapedArrayRef
|
||||
else v.aval for v in jaxpr.invars]
|
||||
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, consts))
|
||||
new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
|
||||
return new_jaxpr, new_consts
|
||||
|
||||
def _dynamic_index(x, idx):
|
||||
if not idx: return x
|
||||
ndim = len(x.shape)
|
||||
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
|
||||
sizes = (1,) * len(idx) + x.shape[len(idx):]
|
||||
out = lax.dynamic_slice(x, starts, sizes)
|
||||
return out.reshape(x.shape[len(idx):])
|
||||
|
||||
def _dynamic_update_index(x, idx, val):
|
||||
if not idx: return val
|
||||
ndim = len(x.shape)
|
||||
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
|
||||
update = val.reshape((1,) * len(idx) + x.shape[len(idx):])
|
||||
return lax.dynamic_update_slice(x, update, starts)
|
||||
|
||||
def _eval_jaxpr_discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any],
|
||||
*args: Any):
|
||||
env: Dict[core.Var, Any] = {}
|
||||
|
||||
def read(v: core.Atom) -> Any:
|
||||
if type(v) is core.Literal:
|
||||
return v.val
|
||||
assert isinstance(v, core.Var)
|
||||
return env[v]
|
||||
|
||||
def write(v: core.Var, val: Any) -> None:
|
||||
env[v] = val
|
||||
|
||||
map(write, jaxpr.constvars, consts)
|
||||
# Here some args may correspond to `Ref` avals but they'll be treated like
|
||||
# regular values in this interpreter.
|
||||
map(write, jaxpr.invars, args)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
in_vals = map(read, eqn.invars)
|
||||
if eqn.primitive is get_p:
|
||||
# `y <- x[i]` becomes `y = ds x i`
|
||||
x, *idx = in_vals
|
||||
write(eqn.outvars[0], _dynamic_index(x, idx))
|
||||
elif eqn.primitive is swap_p:
|
||||
# `z, x[i] <- x[i], val` becomes:
|
||||
# z = ds x i
|
||||
# x = dus x i val
|
||||
x, val, *idx = in_vals
|
||||
write(eqn.outvars[0], _dynamic_index(x, idx))
|
||||
assert isinstance(eqn.invars[0], core.Var)
|
||||
write(eqn.invars[0], _dynamic_update_index(x, idx, val))
|
||||
elif eqn.primitive is addupdate_p:
|
||||
# `x[i] += val` becomes:
|
||||
# y = ds x i
|
||||
# z = y + val
|
||||
# x = dus x i z
|
||||
x, val, *idx = in_vals
|
||||
ans = _dynamic_update_index(x, idx, val + _dynamic_index(x, idx))
|
||||
assert isinstance(eqn.invars[0], core.Var)
|
||||
write(eqn.invars[0], ans)
|
||||
else:
|
||||
# Default primitive rule, similar to `core.eval_jaxpr`. Note that here
|
||||
# we assume any higher-order primitives inside of the jaxpr are *not*
|
||||
# stateful.
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write, eqn.outvars, ans)
|
||||
else:
|
||||
write(eqn.outvars[0], ans)
|
||||
# By convention, we return the outputs of the jaxpr first and then the final
|
||||
# values of the `Ref`s. Callers to this function should be able to split
|
||||
# them up by looking at `len(jaxpr.outvars)`.
|
||||
out_vals = map(read, jaxpr.outvars)
|
||||
ref_vals = map(
|
||||
read, [v for v in jaxpr.invars if type(v.aval) is ShapedArrayRef])
|
||||
return out_vals + ref_vals
|
254
jax/_src/state/primitives.py
Normal file
254
jax/_src/state/primitives.py
Normal file
@ -0,0 +1,254 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for state primitives."""
|
||||
from functools import partial
|
||||
|
||||
from typing import Any, Generic, List, Tuple, TypeVar
|
||||
|
||||
from jax import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax.interpreters import ad
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src.state.types import ShapedArrayRef, StateEffect
|
||||
|
||||
## General utilities
|
||||
|
||||
Array = Any
|
||||
T = TypeVar('T')
|
||||
class Ref(Generic[T]): pass
|
||||
|
||||
## JAX utilities
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
## get/swap/addupdate implementations
|
||||
|
||||
# `get` reads a value from a `Ref` type, a.k.a.:
|
||||
# a = get_p.bind(x)
|
||||
# or we can read using indices:
|
||||
# a = get_p.bind(x, 0, 1)
|
||||
# Staging out `a = get_p.bind(x)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# a:f32[3] <- x[]
|
||||
get_p = core.Primitive("get")
|
||||
|
||||
def _get_impl(ref: Ref, *idx: int):
|
||||
del ref, idx
|
||||
raise ValueError("Cannot run stateful primitive.")
|
||||
get_p.def_impl(_get_impl)
|
||||
|
||||
def ref_get(ref: Ref, idx: Tuple[int]) -> Array:
|
||||
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
|
||||
idx = map(jnp.int32, idx)
|
||||
return get_p.bind(ref, *idx)
|
||||
|
||||
# `swap` mutates a `Ref`, setting its value and returns its previous value.
|
||||
# b = swap_p.bind(x, a)
|
||||
# It generalizes the setting operation for a `Ref` as we can ignore the return
|
||||
# value:
|
||||
# _ = swap_p.bind(x, a)
|
||||
# `swap_p` also takes in index arguments following the value, i.e.:
|
||||
# _ = swap_p.bind(x, a, 0, 1)
|
||||
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is
|
||||
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# b:f32[3], x:Ref{f32[3]} <- x, a
|
||||
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is
|
||||
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
|
||||
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
|
||||
# x:Ref{f32[3]}[i, j] <- a
|
||||
swap_p = core.Primitive("swap")
|
||||
|
||||
def _swap_impl(ref: Ref, value: Array, *idx: int):
|
||||
del ref, value, idx
|
||||
raise ValueError("Cannot run stateful primitive.")
|
||||
swap_p.def_impl(_swap_impl)
|
||||
|
||||
def ref_swap(ref: Ref, idx: Tuple[int], value: Array) -> Array:
|
||||
"""Sets a `Ref`'s value and returns the original value."""
|
||||
idx = map(jnp.int32, idx)
|
||||
return swap_p.bind(ref, value, *idx)
|
||||
|
||||
def ref_set(ref: Ref, idx: Tuple[int], value: Array) -> None:
|
||||
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
|
||||
ref_swap(ref, idx, value)
|
||||
|
||||
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
|
||||
# Semantically,
|
||||
# ```
|
||||
# addupdate ref a *idx
|
||||
# ```
|
||||
# is equivalent to
|
||||
# ```
|
||||
# b = get ref *idx
|
||||
# c = add b x
|
||||
# _ = swap ref c *idx
|
||||
# ```
|
||||
addupdate_p = core.Primitive('addupdate')
|
||||
addupdate_p.multiple_results = True
|
||||
|
||||
def _addupdate_impl(ref: Ref, value: Array, *idx: int):
|
||||
del ref, idx, value
|
||||
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
|
||||
addupdate_p.def_impl(_addupdate_impl)
|
||||
|
||||
def ref_addupdate(ref: Ref, idx: Tuple[int], x: Array) -> None:
|
||||
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
|
||||
return addupdate_p.bind(ref, x, *idx)
|
||||
|
||||
## get/set/addupdate abstract evaluation rules
|
||||
|
||||
def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx: int):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
|
||||
return (core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype),
|
||||
{StateEffect})
|
||||
get_p.def_effectful_abstract_eval(_get_abstract_eval)
|
||||
|
||||
|
||||
def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
|
||||
*idx: int):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
expected_output_shape = ref_aval.shape[len(idx):]
|
||||
if expected_output_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `swap`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
if ref_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Invalid dtype for `swap`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
return (core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype),
|
||||
{StateEffect})
|
||||
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||||
|
||||
|
||||
def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
|
||||
val_aval: core.AbstractValue,
|
||||
*idx: int):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
expected_output_shape = ref_aval.shape[len(idx):]
|
||||
if expected_output_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `swap`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
return [], {StateEffect}
|
||||
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
||||
|
||||
## Pretty printing for `get` and `swap` in jaxprs
|
||||
|
||||
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
|
||||
foreground=pp.Color.GREEN)
|
||||
|
||||
def _get_pp_rule(eqn, context, settings):
|
||||
# Pretty prints `a = get x i` as `x[i] <- a`
|
||||
y, = eqn.outvars
|
||||
x, *idx = eqn.invars
|
||||
idx = ','.join(core.pp_var(i, context) for i in idx)
|
||||
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||
return [lhs, pp.text(' <- '), pp_ref(pp.concat([
|
||||
pp.text(core.pp_var(x, context)), pp.text('['), pp.text(idx), pp.text(']')
|
||||
]))]
|
||||
core.pp_eqn_rules[get_p] = _get_pp_rule
|
||||
|
||||
def _swap_pp_rule(eqn, context, settings):
|
||||
y, = eqn.outvars
|
||||
x, v, *idx = eqn.invars
|
||||
idx = ','.join(core.pp_var(i, context) for i in idx)
|
||||
if type(y) is core.DropVar:
|
||||
# In the case of a set (ignored return value),
|
||||
# pretty print `_ = swap x v i` as `x[i] <- v`
|
||||
del y
|
||||
return [
|
||||
pp_ref(pp.concat([
|
||||
pp.text(core.pp_var(x, context)),
|
||||
pp.text('['), pp.text(idx), pp.text(']')
|
||||
])), pp.text(' <- '), pp.text(core.pp_var(v, context))]
|
||||
else:
|
||||
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
|
||||
x_i = pp.concat([pp.text(core.pp_var(x, context)),
|
||||
pp.text('['), pp.text(idx), pp.text(']')])
|
||||
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||
return [y, pp.text(', '), x_i, pp.text(' <- '),
|
||||
x_i, pp.text(', '), pp.text(core.pp_var(v, context))]
|
||||
core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
||||
|
||||
def _addupdate_pp_rule(eqn, context, settings):
|
||||
# pretty-print ` = addupdate x i v` as `x[i] += v`
|
||||
() = eqn.outvars
|
||||
x, v, *idx = eqn.invars
|
||||
idx = ','.join(core.pp_var(i, context) for i in idx)
|
||||
return [
|
||||
pp_ref(pp.concat([
|
||||
pp.text(core.pp_var(x, context)),
|
||||
pp.text('['), pp.text(idx), pp.text(']')
|
||||
])), pp.text(' += '), pp.text(core.pp_var(v, context))]
|
||||
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
|
||||
|
||||
## get/swap/addupdate JVP rules
|
||||
|
||||
def _get_jvp(primals: List[Any], tangents: List[Any]):
|
||||
ref_primal, *idx = primals
|
||||
assert isinstance(ref_primal.aval, ShapedArrayRef)
|
||||
ref_tangent, *_ = tangents
|
||||
assert isinstance(ref_tangent.aval, ShapedArrayRef)
|
||||
return ref_get(ref_primal, idx), ref_get(ref_tangent, idx) # type: ignore[arg-type]
|
||||
ad.primitive_jvps[get_p] = _get_jvp
|
||||
|
||||
def _swap_jvp(primals: List[Any], tangents: List[Any]):
|
||||
ref_primal, x_primal, *idx = primals
|
||||
assert isinstance(ref_primal.aval, ShapedArrayRef)
|
||||
ref_tangent, x_tangent, *_ = tangents
|
||||
assert isinstance(ref_tangent.aval, ShapedArrayRef)
|
||||
x_tangent = ad_util.instantiate(x_tangent)
|
||||
return (ref_swap(ref_primal, idx, x_primal), # type: ignore[arg-type]
|
||||
ref_swap(ref_tangent, idx, x_tangent)) # type: ignore[arg-type]
|
||||
ad.primitive_jvps[swap_p] = _swap_jvp
|
||||
|
||||
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any]):
|
||||
ref_primal, x_primal, *idx = primals
|
||||
ref_tangent, x_tangent, *_ = tangents
|
||||
x_tangent = ad_util.instantiate(x_tangent)
|
||||
addupdate_p.bind(ref_primal, x_primal, *idx)
|
||||
addupdate_p.bind(ref_tangent, x_tangent, *idx)
|
||||
return [], []
|
||||
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule
|
||||
|
||||
## get/swap/addupdate transpose rules
|
||||
|
||||
def _get_transpose(g, ref, *idx):
|
||||
# get transpose is addupdate
|
||||
if type(g) is not ad_util.Zero:
|
||||
ref_addupdate(ref, idx, g)
|
||||
return [None] + [None] * len(idx)
|
||||
ad.primitive_transposes[get_p] = _get_transpose
|
||||
|
||||
def _swap_transpose(g, ref, x, *idx):
|
||||
# swap transpose is swap
|
||||
x_bar = ref_swap(ref, idx, ad_util.instantiate(g))
|
||||
return [None, x_bar] + [None] * len(idx)
|
||||
ad.primitive_transposes[swap_p] = _swap_transpose
|
107
jax/_src/state/types.py
Normal file
107
jax/_src/state/types.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for state types."""
|
||||
from functools import partial
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from jax import api_util
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax import tree_util
|
||||
from jax._src import ad_util
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.lib import xla_bridge, xla_client
|
||||
from jax._src.util import safe_map, safe_zip, split_list
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
import numpy as np
|
||||
|
||||
xc = xla_client
|
||||
xb = xla_bridge
|
||||
|
||||
## JAX utilities
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
Array = Any
|
||||
|
||||
class _StateEffect:
|
||||
def __repr__(self):
|
||||
return "State"
|
||||
__str__ = __repr__
|
||||
StateEffect = _StateEffect()
|
||||
|
||||
# ## `Ref`s
|
||||
|
||||
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
|
||||
# A `ShapedArrayRef` is a abstract value for mutable containers of array types
|
||||
class ShapedArrayRef(core.AbstractValue):
|
||||
__slots__ = ["shape", "dtype"]
|
||||
|
||||
def __init__(self, shape, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
def join(self, other):
|
||||
assert core.symbolic_equal_shape(self.shape, other.shape)
|
||||
assert self.dtype == other.dtype
|
||||
return self
|
||||
|
||||
@core.aval_method
|
||||
@staticmethod
|
||||
def get(tracer, idx=()):
|
||||
from jax._src.state.primitives import ref_get
|
||||
return ref_get(tracer, idx)
|
||||
|
||||
@core.aval_method
|
||||
@staticmethod
|
||||
def set(tracer, value, idx=()):
|
||||
from jax._src.state.primitives import ref_set
|
||||
return ref_set(tracer, idx, value)
|
||||
|
||||
def _getitem(self, tracer, idx) -> Array:
|
||||
if not isinstance(idx, tuple):
|
||||
idx = idx,
|
||||
from jax._src.state.primitives import ref_get
|
||||
return ref_get(tracer, idx)
|
||||
|
||||
def _setitem(self, tracer, idx, value) -> None:
|
||||
if not isinstance(idx, tuple):
|
||||
idx = idx,
|
||||
from jax._src.state.primitives import ref_set
|
||||
return ref_set(tracer, idx, value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
a = core.ShapedArray(self.shape, self.dtype)
|
||||
return f'Ref{{{a.str_short()}}}'
|
||||
|
||||
def at_least_vspace(self):
|
||||
return self
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other)
|
||||
and self.dtype == other.dtype and self.shape == other.shape)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.shape, self.dtype))
|
||||
|
||||
|
||||
core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval
|
@ -871,6 +871,15 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "state_test",
|
||||
srcs = ["state_test.py"],
|
||||
enable_configs = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "clear_backends_test",
|
||||
srcs = ["clear_backends_test.py"],
|
||||
|
@ -29,13 +29,11 @@ import jax
|
||||
from jax import core
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax import random
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.util import unzip2
|
||||
from jax.experimental import maps
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
|
||||
import jax.numpy as jnp # scan tests use numpy
|
||||
import jax.scipy as jsp
|
||||
@ -2563,320 +2561,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
class ForLoopTest(jtu.JaxTestCase):
|
||||
|
||||
def test_cant_eval_get_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.get_p.bind(jnp.ones(5))
|
||||
|
||||
def test_cant_eval_swap_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.swap_p.bind(jnp.ones(5), jnp.zeros(5))
|
||||
|
||||
def test_cant_eval_addupdate_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.addupdate_p.bind(jnp.ones(5), jnp.zeros(5))
|
||||
|
||||
def test_get_abstract_eval(self):
|
||||
ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32)
|
||||
out_aval, effect = for_loop.get_p.abstract_eval(ref_aval, 0)
|
||||
self.assertSetEqual(effect, {for_loop.State})
|
||||
self.assertTupleEqual(out_aval.shape, (2, 3))
|
||||
self.assertEqual(out_aval.dtype, jnp.float32)
|
||||
|
||||
def test_get_abstract_aval_must_take_in_refs(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.get_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32))
|
||||
|
||||
def test_swap_abstract_eval(self):
|
||||
ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32)
|
||||
val_aval = core.ShapedArray((2, 3), jnp.float32)
|
||||
out_aval, effect = for_loop.swap_p.abstract_eval(ref_aval, val_aval, 0)
|
||||
self.assertSetEqual(effect, {for_loop.State})
|
||||
self.assertTupleEqual(out_aval.shape, (2, 3))
|
||||
self.assertEqual(out_aval.dtype, jnp.float32)
|
||||
|
||||
def test_swap_abstract_eval_must_take_in_refs(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.swap_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3), jnp.float32))
|
||||
|
||||
def test_swap_checks_for_correct_shapes(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.swap_p.abstract_eval(
|
||||
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32))
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.swap_p.abstract_eval(
|
||||
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3, 4), jnp.float32))
|
||||
for_loop.swap_p.abstract_eval(
|
||||
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32), 1)
|
||||
|
||||
def test_addupdate_abstract_eval(self):
|
||||
ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32)
|
||||
val_aval = core.ShapedArray((2, 3), jnp.float32)
|
||||
out_avals, effect = for_loop.addupdate_p.abstract_eval(ref_aval, val_aval,
|
||||
0)
|
||||
self.assertSetEqual(effect, {for_loop.State})
|
||||
self.assertListEqual(out_avals, [])
|
||||
|
||||
def test_addupdate_abstract_eval_must_take_in_refs(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.addupdate_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3), jnp.float32))
|
||||
|
||||
def test_addupdate_checks_for_correct_shapes(self):
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.addupdate_p.abstract_eval(
|
||||
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32))
|
||||
with self.assertRaises(ValueError):
|
||||
for_loop.addupdate_p.abstract_eval(
|
||||
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3, 4), jnp.float32))
|
||||
for_loop.addupdate_p.abstract_eval(
|
||||
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32), 1)
|
||||
|
||||
def test_can_represent_get_and_swap_in_jaxprs(self):
|
||||
|
||||
def body(x):
|
||||
x[()] = jnp.int32(1)
|
||||
x[()] = jnp.int32(2)
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.swap_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.swap_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, for_loop.get_p)
|
||||
|
||||
def test_can_represent_addupdate_in_jaxprs(self):
|
||||
|
||||
def body(x):
|
||||
for_loop.ref_addupdate(x, (), jnp.int32(1))
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.addupdate_p)
|
||||
|
||||
def test_get_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
x = x_ref[()]
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_set_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
x_ref[()] = jnp.int32(2)
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_swap_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
x = for_loop.ref_swap(x_ref, (), jnp.int32(2))
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_addupdate_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
for_loop.ref_addupdate(x_ref, (), jnp.int32(2))
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_get_jvp(self):
|
||||
|
||||
def f(r):
|
||||
x = r[()]
|
||||
return jnp.cos(x)
|
||||
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.get_p)
|
||||
|
||||
def test_swap_jvp(self):
|
||||
|
||||
def f(a):
|
||||
x = a[()]
|
||||
a[()] = jnp.sin(x)
|
||||
return a[()]
|
||||
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.get_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
|
||||
self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p)
|
||||
self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p)
|
||||
self.assertEqual(jaxpr.eqns[5].primitive, for_loop.swap_p)
|
||||
self.assertEqual(jaxpr.eqns[6].primitive, for_loop.swap_p)
|
||||
|
||||
def test_addupdate_jvp(self):
|
||||
|
||||
def f(a):
|
||||
for_loop.ref_addupdate(a, (), 1.)
|
||||
return a[()]
|
||||
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.addupdate_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.addupdate_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, for_loop.get_p)
|
||||
self.assertEqual(jaxpr.eqns[3].primitive, for_loop.get_p)
|
||||
|
||||
def test_discharge_get(self):
|
||||
def f(a_ref):
|
||||
a = for_loop.ref_get(a_ref, ())
|
||||
return [a + 1]
|
||||
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 2)
|
||||
self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p)
|
||||
# Should be able to evaluate this jaxpr
|
||||
self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (),
|
||||
jnp.float32(1.)), [2., 1.])
|
||||
|
||||
def test_discharge_get_with_slice(self):
|
||||
def f(a_ref):
|
||||
a = for_loop.ref_get(a_ref, (0, 1))
|
||||
return [a + 1]
|
||||
in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 2)
|
||||
self.assertIn(lax.dynamic_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
# Should be able to evaluate this jaxpr
|
||||
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
|
||||
outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((outval == inval[0, 1] + 1).all())
|
||||
self.assertTrue((refval == inval).all())
|
||||
|
||||
def test_discharge_set(self):
|
||||
def f(a_ref, b):
|
||||
for_loop.ref_set(a_ref, (), b + 1)
|
||||
return []
|
||||
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that ignores the first
|
||||
# value and returns second value plus 1.
|
||||
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 2)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
|
||||
jnp.float32(1.))[0], 2.)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
|
||||
jnp.float32(1.))[0], 2.)
|
||||
|
||||
def test_discharge_set_with_slice(self):
|
||||
def f(a_ref):
|
||||
for_loop.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertIn(lax.dynamic_update_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
self.assertIn(lax.dynamic_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
# Should be able to evaluate this jaxpr
|
||||
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
|
||||
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((refval == inval.at[0, 1].set(1.)).all())
|
||||
|
||||
def test_discharge_addupdate(self):
|
||||
def f(a_ref, b):
|
||||
for_loop.ref_addupdate(a_ref, (), b + 1)
|
||||
return []
|
||||
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that adds the first value,
|
||||
# second value, and 1.
|
||||
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 2)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
|
||||
jnp.float32(1.))[0], 2.)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
|
||||
jnp.float32(1.))[0], 4.)
|
||||
|
||||
def test_discharge_addupdate_with_slice(self):
|
||||
def f(a_ref):
|
||||
for_loop.ref_addupdate(a_ref, (0, 1),
|
||||
jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertIn(lax.dynamic_update_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
self.assertIn(lax.add_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
self.assertIn(lax.dynamic_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
|
||||
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((refval == inval.at[0, 1].add(1.)).all())
|
||||
|
||||
def test_discharge_jaxpr_with_multiple_outputs(self):
|
||||
def f(a_ref):
|
||||
a = for_loop.ref_get(a_ref, ())
|
||||
b = a + 1
|
||||
return [a, b]
|
||||
in_avals = [for_loop.ShapedArrayRef((4,), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 3)
|
||||
inval = jnp.arange(4., dtype=jnp.float32)
|
||||
a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((a == inval).all())
|
||||
self.assertTrue((b == inval + 1).all())
|
||||
self.assertTrue((refval == inval).all())
|
||||
|
||||
def test_for_loop_impl_trivial(self):
|
||||
out = for_loop.for_loop(5, lambda i, _: None, None)
|
||||
self.assertEqual(out, None)
|
||||
|
348
tests/state_test.py
Normal file
348
tests/state_test.py
Normal file
@ -0,0 +1,348 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src import state
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_cant_eval_get_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.get_p.bind(jnp.ones(5))
|
||||
|
||||
def test_cant_eval_swap_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.swap_p.bind(jnp.ones(5), jnp.zeros(5))
|
||||
|
||||
def test_cant_eval_addupdate_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.addupdate_p.bind(jnp.ones(5), jnp.zeros(5))
|
||||
|
||||
def test_get_abstract_eval(self):
|
||||
ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32)
|
||||
out_aval, effect = state.get_p.abstract_eval(ref_aval, 0)
|
||||
self.assertSetEqual(effect, {state.StateEffect})
|
||||
self.assertTupleEqual(out_aval.shape, (2, 3))
|
||||
self.assertEqual(out_aval.dtype, jnp.float32)
|
||||
|
||||
def test_get_abstract_aval_must_take_in_refs(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.get_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32))
|
||||
|
||||
def test_swap_abstract_eval(self):
|
||||
ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32)
|
||||
val_aval = core.ShapedArray((2, 3), jnp.float32)
|
||||
out_aval, effect = state.swap_p.abstract_eval(ref_aval, val_aval, 0)
|
||||
self.assertSetEqual(effect, {state.StateEffect})
|
||||
self.assertTupleEqual(out_aval.shape, (2, 3))
|
||||
self.assertEqual(out_aval.dtype, jnp.float32)
|
||||
|
||||
def test_swap_abstract_eval_must_take_in_refs(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.swap_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3), jnp.float32))
|
||||
|
||||
def test_swap_checks_for_correct_shapes(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.swap_p.abstract_eval(
|
||||
state.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32))
|
||||
with self.assertRaises(ValueError):
|
||||
state.swap_p.abstract_eval(
|
||||
state.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3, 4), jnp.float32))
|
||||
state.swap_p.abstract_eval(
|
||||
state.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32), 1)
|
||||
|
||||
def test_addupdate_abstract_eval(self):
|
||||
ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32)
|
||||
val_aval = core.ShapedArray((2, 3), jnp.float32)
|
||||
out_avals, effect = state.addupdate_p.abstract_eval(ref_aval, val_aval,
|
||||
0)
|
||||
self.assertSetEqual(effect, {state.StateEffect})
|
||||
self.assertListEqual(out_avals, [])
|
||||
|
||||
def test_addupdate_abstract_eval_must_take_in_refs(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.addupdate_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3), jnp.float32))
|
||||
|
||||
def test_addupdate_checks_for_correct_shapes(self):
|
||||
with self.assertRaises(ValueError):
|
||||
state.addupdate_p.abstract_eval(
|
||||
state.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32))
|
||||
with self.assertRaises(ValueError):
|
||||
state.addupdate_p.abstract_eval(
|
||||
state.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((1, 2, 3, 4), jnp.float32))
|
||||
state.addupdate_p.abstract_eval(
|
||||
state.ShapedArrayRef((1, 2, 3), jnp.float32),
|
||||
core.ShapedArray((2, 3), jnp.float32), 1)
|
||||
|
||||
def test_can_represent_get_and_swap_in_jaxprs(self):
|
||||
|
||||
def body(x):
|
||||
x[()] = jnp.int32(1)
|
||||
x[()] = jnp.int32(2)
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, state.swap_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)
|
||||
|
||||
def test_can_represent_addupdate_in_jaxprs(self):
|
||||
|
||||
def body(x):
|
||||
state.ref_addupdate(x, (), jnp.int32(1))
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
|
||||
|
||||
def test_get_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
x = x_ref[()]
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_set_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
x_ref[()] = jnp.int32(2)
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_swap_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
x = state.ref_swap(x_ref, (), jnp.int32(2))
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_addupdate_custom_pretty_printing_rule(self):
|
||||
def body(x_ref):
|
||||
state.ref_addupdate(x_ref, (), jnp.int32(2))
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_get_jvp(self):
|
||||
|
||||
def f(r):
|
||||
x = r[()]
|
||||
return jnp.cos(x)
|
||||
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
|
||||
|
||||
def test_swap_jvp(self):
|
||||
|
||||
def f(a):
|
||||
x = a[()]
|
||||
a[()] = jnp.sin(x)
|
||||
return a[()]
|
||||
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
|
||||
self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p)
|
||||
self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p)
|
||||
self.assertEqual(jaxpr.eqns[5].primitive, state.swap_p)
|
||||
self.assertEqual(jaxpr.eqns[6].primitive, state.swap_p)
|
||||
|
||||
def test_addupdate_jvp(self):
|
||||
|
||||
def f(a):
|
||||
state.ref_addupdate(a, (), 1.)
|
||||
return a[()]
|
||||
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p)
|
||||
self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)
|
||||
self.assertEqual(jaxpr.eqns[3].primitive, state.get_p)
|
||||
|
||||
class StateDischargeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_discharge_get(self):
|
||||
def f(a_ref):
|
||||
a = state.ref_get(a_ref, ())
|
||||
return [a + 1]
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 2)
|
||||
self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p)
|
||||
# Should be able to evaluate this jaxpr
|
||||
self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (),
|
||||
jnp.float32(1.)), [2., 1.])
|
||||
|
||||
def test_discharge_get_with_slice(self):
|
||||
def f(a_ref):
|
||||
a = state.ref_get(a_ref, (0, 1))
|
||||
return [a + 1]
|
||||
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 2)
|
||||
self.assertIn(lax.dynamic_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
# Should be able to evaluate this jaxpr
|
||||
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
|
||||
outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((outval == inval[0, 1] + 1).all())
|
||||
self.assertTrue((refval == inval).all())
|
||||
|
||||
def test_discharge_set(self):
|
||||
def f(a_ref, b):
|
||||
state.ref_set(a_ref, (), b + 1)
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that ignores the first
|
||||
# value and returns second value plus 1.
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 2)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
|
||||
jnp.float32(1.))[0], 2.)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
|
||||
jnp.float32(1.))[0], 2.)
|
||||
|
||||
def test_discharge_set_with_slice(self):
|
||||
def f(a_ref):
|
||||
state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertIn(lax.dynamic_update_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
self.assertIn(lax.dynamic_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
# Should be able to evaluate this jaxpr
|
||||
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
|
||||
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((refval == inval.at[0, 1].set(1.)).all())
|
||||
|
||||
def test_discharge_addupdate(self):
|
||||
def f(a_ref, b):
|
||||
state.ref_addupdate(a_ref, (), b + 1)
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that adds the first value,
|
||||
# second value, and 1.
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 2)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
|
||||
jnp.float32(1.))[0], 2.)
|
||||
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
|
||||
jnp.float32(1.))[0], 4.)
|
||||
|
||||
def test_discharge_addupdate_with_slice(self):
|
||||
def f(a_ref):
|
||||
state.ref_addupdate(a_ref, (0, 1),
|
||||
jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertIn(lax.dynamic_update_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
self.assertIn(lax.add_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
self.assertIn(lax.dynamic_slice_p,
|
||||
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
|
||||
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
|
||||
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((refval == inval.at[0, 1].add(1.)).all())
|
||||
|
||||
def test_discharge_jaxpr_with_multiple_outputs(self):
|
||||
def f(a_ref):
|
||||
a = state.ref_get(a_ref, ())
|
||||
b = a + 1
|
||||
return [a, b]
|
||||
in_avals = [state.ShapedArrayRef((4,), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
self.assertLen(discharged_jaxpr.invars, 1)
|
||||
self.assertLen(discharged_jaxpr.outvars, 3)
|
||||
inval = jnp.arange(4., dtype=jnp.float32)
|
||||
a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
|
||||
self.assertTrue((a == inval).all())
|
||||
self.assertTrue((b == inval + 1).all())
|
||||
self.assertTrue((refval == inval).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user