Merge pull request #11701 from sharadmv:state

PiperOrigin-RevId: 464658336
This commit is contained in:
jax authors 2022-08-01 17:05:43 -07:00
commit 01819257f6
8 changed files with 866 additions and 661 deletions

View File

@ -15,7 +15,7 @@
from functools import partial
import operator
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar
from jax import core
from jax import lax
@ -29,8 +29,8 @@ from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
treedef_tuple, tree_map, tree_leaves, PyTreeDef)
from jax._src import ad_util
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import state
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
split_list)
import jax.numpy as jnp
@ -49,350 +49,13 @@ T = TypeVar('T')
class Ref(Generic[T]): pass
Array = Any
## State effect
StateEffect = state.StateEffect
ShapedArrayRef = state.ShapedArrayRef
ref_set = state.ref_set
ref_get = state.ref_get
ref_addupdate = state.ref_addupdate
discharge_state = state.discharge_state
class StateEffect: pass
State = StateEffect()
## get/swap/addupdate implementations
# `get` reads a value from a `Ref` type, a.k.a.:
# a = get_p.bind(x)
# or we can read using indices:
# a = get_p.bind(x, 0, 1)
# Staging out `a = get_p.bind(x)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# a:f32[3] <- x[]
get_p = core.Primitive("get")
def _get_impl(ref: Ref, *idx: int):
del ref, idx
raise ValueError("Can't evaluate `get` outside a stateful context.")
get_p.def_impl(_get_impl)
def ref_get(ref: Ref, idx: Tuple[int]) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
idx = map(jnp.int32, idx)
return get_p.bind(ref, *idx)
# `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a)
# It generalizes the setting operation for a `Ref` as we can ignore the return
# value:
# _ = swap_p.bind(x, a)
# `swap_p` also takes in index arguments following the value, i.e.:
# _ = swap_p.bind(x, a, 0, 1)
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# b:f32[3], x:Ref{f32[3]} <- x, a
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
# x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap")
def _swap_impl(ref: Ref, value: Array, *idx: int):
del ref, idx, value
raise ValueError("Can't evaluate `swap` outside a stateful context.")
swap_p.def_impl(_swap_impl)
def ref_swap(ref: Ref, idx: Tuple[int], value: Array) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
idx = map(jnp.int32, idx)
return swap_p.bind(ref, value, *idx)
def ref_set(ref: Ref, idx: Tuple[int], value: Array) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref, idx, value)
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
# Semantically,
# ```
# addupdate ref a *idx
# ```
# is equivalent to
# ```
# b = get ref *idx
# c = add b x
# _ = swap ref c *idx
# ```
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True
def _addupdate_impl(ref: Ref, value: Array, *idx: int):
del ref, idx, value
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl)
def ref_addupdate(ref: Ref, idx: Tuple[int], x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
return addupdate_p.bind(ref, x, *idx)
## get/set/addupdate abstract evaluation rules
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
# A `ShapedArrayRef` is a abstract value for mutable containers of array types
class ShapedArrayRef(core.AbstractValue):
__slots__ = ["shape", "dtype"]
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
def join(self, other):
assert core.symbolic_equal_shape(self.shape, other.shape)
assert self.dtype == other.dtype
return self
def _getitem(self, tracer, idx) -> Array:
if not isinstance(idx, tuple):
idx = idx,
return ref_get(tracer, idx)
def _setitem(self, tracer, idx, val) -> None:
if not isinstance(idx, tuple):
idx = idx,
return ref_set(tracer, idx, val)
def __repr__(self) -> str:
a = core.ShapedArray(self.shape, self.dtype)
return f'Ref{{{a.str_short()}}}'
def at_least_vspace(self):
return self
core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval
def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx: int):
if not isinstance(ref_aval, ShapedArrayRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State}
get_p.def_effectful_abstract_eval(_get_abstract_eval)
def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
*idx: int):
if not isinstance(ref_aval, ShapedArrayRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
expected_output_shape = ref_aval.shape[len(idx):]
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State}
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
val_aval: core.AbstractValue,
*idx: int):
if not isinstance(ref_aval, ShapedArrayRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
expected_output_shape = ref_aval.shape[len(idx):]
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
return [], {State}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
## Pretty printing for `get` and `swap` in jaxprs
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
foreground=pp.Color.GREEN)
def _get_pp_rule(eqn, context, settings):
# Pretty prints `a = get x i` as `a <- x[i]`
y, = eqn.outvars
x, *idx = eqn.invars
idx = ','.join(core.pp_var(i, context) for i in idx)
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return [lhs, pp.text(' <- '), pp_ref(pp.concat([
pp.text(core.pp_var(x, context)), pp.text('['), pp.text(idx), pp.text(']')
]))]
core.pp_eqn_rules[get_p] = _get_pp_rule
def _swap_pp_rule(eqn, context, settings):
y, = eqn.outvars
x, v, *idx = eqn.invars
idx = ','.join(core.pp_var(i, context) for i in idx)
if type(y) is core.DropVar:
# In the case of a set (ignored return value),
# pretty print `_ = swap x v i` as `x[i] <- v`
del y
return [
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), pp.text(idx), pp.text(']')
])), pp.text(' <- '), pp.text(core.pp_var(v, context))]
else:
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
x_i = pp.concat([pp.text(core.pp_var(x, context)),
pp.text('['), pp.text(idx), pp.text(']')])
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return [y, pp.text(', '), x_i, pp.text(' <- '),
x_i, pp.text(', '), pp.text(core.pp_var(v, context))]
core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _addupdate_pp_rule(eqn, context, settings):
# pretty-print ` = addupdate x i v` as `x[i] += v`
() = eqn.outvars
x, v, *idx = eqn.invars
idx = ','.join(core.pp_var(i, context) for i in idx)
return [
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), pp.text(idx), pp.text(']')
])), pp.text(' += '), pp.text(core.pp_var(v, context))]
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules
def _get_jvp(primals: List[Any], tangents: List[Any]):
ref_primal, *idx = primals
assert isinstance(ref_primal.aval, ShapedArrayRef)
ref_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, ShapedArrayRef)
return ref_get(ref_primal, idx), ref_get(ref_tangent, idx) # type: ignore[arg-type]
ad.primitive_jvps[get_p] = _get_jvp
def _swap_jvp(primals: List[Any], tangents: List[Any]):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, ShapedArrayRef)
ref_tangent, x_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, ShapedArrayRef)
x_tangent = ad_util.instantiate(x_tangent)
return (ref_swap(ref_primal, idx, x_primal), # type: ignore[arg-type]
ref_swap(ref_tangent, idx, x_tangent)) # type: ignore[arg-type]
ad.primitive_jvps[swap_p] = _swap_jvp
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any]):
ref_primal, x_primal, *idx = primals
ref_tangent, x_tangent, *_ = tangents
x_tangent = ad_util.instantiate(x_tangent)
addupdate_p.bind(ref_primal, x_primal, *idx)
addupdate_p.bind(ref_tangent, x_tangent, *idx)
return [], []
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule
## get/swap/addupdate transpose rules
def _get_transpose(g, ref, *idx):
# get transpose is addupdate
if type(g) is not ad_util.Zero:
ref_addupdate(ref, idx, g)
return [None] + [None] * len(idx)
ad.primitive_transposes[get_p] = _get_transpose
def _swap_transpose(g, ref, x, *idx):
# swap transpose is swap
x_bar = ref_swap(ref, idx, ad_util.instantiate(g))
return [None, x_bar] + [None] * len(idx)
ad.primitive_transposes[swap_p] = _swap_transpose
## Discharging state
# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values
# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr
# into a "pure" jaxpr that takes in and outputs values and no longer has the
# `State` effect.
def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]:
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
if type(v.aval) is ShapedArrayRef
else v.aval for v in jaxpr.invars]
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, consts))
new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
return new_jaxpr, new_consts
def _dynamic_index(x, idx):
if not idx: return x
ndim = len(x.shape)
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
sizes = (1,) * len(idx) + x.shape[len(idx):]
out = lax.dynamic_slice(x, starts, sizes)
return out.reshape(x.shape[len(idx):])
def _dynamic_update_index(x, idx, val):
if not idx: return val
ndim = len(x.shape)
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
update = val.reshape((1,) * len(idx) + x.shape[len(idx):])
return lax.dynamic_update_slice(x, update, starts)
def _eval_jaxpr_discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any],
*args: Any):
env: Dict[core.Var, Any] = {}
def read(v: core.Atom) -> Any:
if type(v) is core.Literal:
return v.val
assert isinstance(v, core.Var)
return env[v]
def write(v: core.Var, val: Any) -> None:
env[v] = val
map(write, jaxpr.constvars, consts)
# Here some args may correspond to `Ref` avals but they'll be treated like
# regular values in this interpreter.
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.invars)
if eqn.primitive is get_p:
# `y <- x[i]` becomes `y = ds x i`
x, *idx = in_vals
write(eqn.outvars[0], _dynamic_index(x, idx))
elif eqn.primitive is swap_p:
# `z, x[i] <- x[i], val` becomes:
# z = ds x i
# x = dus x i val
x, val, *idx = in_vals
write(eqn.outvars[0], _dynamic_index(x, idx))
assert isinstance(eqn.invars[0], core.Var)
write(eqn.invars[0], _dynamic_update_index(x, idx, val))
elif eqn.primitive is addupdate_p:
# `x[i] += val` becomes:
# y = ds x i
# z = y + val
# x = dus x i z
x, val, *idx = in_vals
ans = _dynamic_update_index(x, idx, val + _dynamic_index(x, idx))
assert isinstance(eqn.invars[0], core.Var)
write(eqn.invars[0], ans)
else:
# Default primitive rule, similar to `core.eval_jaxpr`. Note that here
# we assume any higher-order primitives inside of the jaxpr are *not*
# stateful.
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
# By convention, we return the outputs of the jaxpr first and then the final
# values of the `Ref`s. Callers to this function should be able to split
# them up by looking at `len(jaxpr.outvars)`.
out_vals = map(read, jaxpr.outvars)
ref_vals = map(
read, [v for v in jaxpr.invars if type(v.aval) is ShapedArrayRef])
return out_vals + ref_vals
## `for_loop` implementation

View File

@ -0,0 +1,19 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state."""
from jax._src.state.types import ShapedArrayRef, StateEffect
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
ref_addupdate, get_p, swap_p,
addupdate_p)
from jax._src.state.discharge import discharge_state

121
jax/_src/state/discharge.py Normal file
View File

@ -0,0 +1,121 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for discharging state primitives."""
from functools import partial
from typing import Any, Dict, List, Sequence, Tuple
from jax import core
from jax import lax
from jax import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax._src.util import safe_map, safe_zip
from jax._src.state.types import ShapedArrayRef
from jax._src.state.primitives import get_p, swap_p, addupdate_p
## JAX utilities
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
## Discharging state
# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values
# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr
# into a "pure" jaxpr that takes in and outputs values and no longer has the
# `StateEffect` effect.
def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]:
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
if type(v.aval) is ShapedArrayRef
else v.aval for v in jaxpr.invars]
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, consts))
new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
return new_jaxpr, new_consts
def _dynamic_index(x, idx):
if not idx: return x
ndim = len(x.shape)
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
sizes = (1,) * len(idx) + x.shape[len(idx):]
out = lax.dynamic_slice(x, starts, sizes)
return out.reshape(x.shape[len(idx):])
def _dynamic_update_index(x, idx, val):
if not idx: return val
ndim = len(x.shape)
starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
update = val.reshape((1,) * len(idx) + x.shape[len(idx):])
return lax.dynamic_update_slice(x, update, starts)
def _eval_jaxpr_discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any],
*args: Any):
env: Dict[core.Var, Any] = {}
def read(v: core.Atom) -> Any:
if type(v) is core.Literal:
return v.val
assert isinstance(v, core.Var)
return env[v]
def write(v: core.Var, val: Any) -> None:
env[v] = val
map(write, jaxpr.constvars, consts)
# Here some args may correspond to `Ref` avals but they'll be treated like
# regular values in this interpreter.
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.invars)
if eqn.primitive is get_p:
# `y <- x[i]` becomes `y = ds x i`
x, *idx = in_vals
write(eqn.outvars[0], _dynamic_index(x, idx))
elif eqn.primitive is swap_p:
# `z, x[i] <- x[i], val` becomes:
# z = ds x i
# x = dus x i val
x, val, *idx = in_vals
write(eqn.outvars[0], _dynamic_index(x, idx))
assert isinstance(eqn.invars[0], core.Var)
write(eqn.invars[0], _dynamic_update_index(x, idx, val))
elif eqn.primitive is addupdate_p:
# `x[i] += val` becomes:
# y = ds x i
# z = y + val
# x = dus x i z
x, val, *idx = in_vals
ans = _dynamic_update_index(x, idx, val + _dynamic_index(x, idx))
assert isinstance(eqn.invars[0], core.Var)
write(eqn.invars[0], ans)
else:
# Default primitive rule, similar to `core.eval_jaxpr`. Note that here
# we assume any higher-order primitives inside of the jaxpr are *not*
# stateful.
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
# By convention, we return the outputs of the jaxpr first and then the final
# values of the `Ref`s. Callers to this function should be able to split
# them up by looking at `len(jaxpr.outvars)`.
out_vals = map(read, jaxpr.outvars)
ref_vals = map(
read, [v for v in jaxpr.invars if type(v.aval) is ShapedArrayRef])
return out_vals + ref_vals

View File

@ -0,0 +1,254 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state primitives."""
from functools import partial
from typing import Any, Generic, List, Tuple, TypeVar
from jax import core
from jax._src import ad_util
from jax._src import pretty_printer as pp
from jax._src.util import safe_map, safe_zip
from jax.interpreters import ad
import jax.numpy as jnp
from jax._src.state.types import ShapedArrayRef, StateEffect
## General utilities
Array = Any
T = TypeVar('T')
class Ref(Generic[T]): pass
## JAX utilities
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
## get/swap/addupdate implementations
# `get` reads a value from a `Ref` type, a.k.a.:
# a = get_p.bind(x)
# or we can read using indices:
# a = get_p.bind(x, 0, 1)
# Staging out `a = get_p.bind(x)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# a:f32[3] <- x[]
get_p = core.Primitive("get")
def _get_impl(ref: Ref, *idx: int):
del ref, idx
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)
def ref_get(ref: Ref, idx: Tuple[int]) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
idx = map(jnp.int32, idx)
return get_p.bind(ref, *idx)
# `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a)
# It generalizes the setting operation for a `Ref` as we can ignore the return
# value:
# _ = swap_p.bind(x, a)
# `swap_p` also takes in index arguments following the value, i.e.:
# _ = swap_p.bind(x, a, 0, 1)
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# b:f32[3], x:Ref{f32[3]} <- x, a
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
# x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap")
def _swap_impl(ref: Ref, value: Array, *idx: int):
del ref, value, idx
raise ValueError("Cannot run stateful primitive.")
swap_p.def_impl(_swap_impl)
def ref_swap(ref: Ref, idx: Tuple[int], value: Array) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
idx = map(jnp.int32, idx)
return swap_p.bind(ref, value, *idx)
def ref_set(ref: Ref, idx: Tuple[int], value: Array) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref, idx, value)
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
# Semantically,
# ```
# addupdate ref a *idx
# ```
# is equivalent to
# ```
# b = get ref *idx
# c = add b x
# _ = swap ref c *idx
# ```
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True
def _addupdate_impl(ref: Ref, value: Array, *idx: int):
del ref, idx, value
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl)
def ref_addupdate(ref: Ref, idx: Tuple[int], x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
return addupdate_p.bind(ref, x, *idx)
## get/set/addupdate abstract evaluation rules
def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx: int):
if not isinstance(ref_aval, ShapedArrayRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
return (core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype),
{StateEffect})
get_p.def_effectful_abstract_eval(_get_abstract_eval)
def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
*idx: int):
if not isinstance(ref_aval, ShapedArrayRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
expected_output_shape = ref_aval.shape[len(idx):]
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
return (core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype),
{StateEffect})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
val_aval: core.AbstractValue,
*idx: int):
if not isinstance(ref_aval, ShapedArrayRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
expected_output_shape = ref_aval.shape[len(idx):]
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
return [], {StateEffect}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
## Pretty printing for `get` and `swap` in jaxprs
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
foreground=pp.Color.GREEN)
def _get_pp_rule(eqn, context, settings):
# Pretty prints `a = get x i` as `x[i] <- a`
y, = eqn.outvars
x, *idx = eqn.invars
idx = ','.join(core.pp_var(i, context) for i in idx)
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return [lhs, pp.text(' <- '), pp_ref(pp.concat([
pp.text(core.pp_var(x, context)), pp.text('['), pp.text(idx), pp.text(']')
]))]
core.pp_eqn_rules[get_p] = _get_pp_rule
def _swap_pp_rule(eqn, context, settings):
y, = eqn.outvars
x, v, *idx = eqn.invars
idx = ','.join(core.pp_var(i, context) for i in idx)
if type(y) is core.DropVar:
# In the case of a set (ignored return value),
# pretty print `_ = swap x v i` as `x[i] <- v`
del y
return [
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), pp.text(idx), pp.text(']')
])), pp.text(' <- '), pp.text(core.pp_var(v, context))]
else:
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
x_i = pp.concat([pp.text(core.pp_var(x, context)),
pp.text('['), pp.text(idx), pp.text(']')])
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return [y, pp.text(', '), x_i, pp.text(' <- '),
x_i, pp.text(', '), pp.text(core.pp_var(v, context))]
core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _addupdate_pp_rule(eqn, context, settings):
# pretty-print ` = addupdate x i v` as `x[i] += v`
() = eqn.outvars
x, v, *idx = eqn.invars
idx = ','.join(core.pp_var(i, context) for i in idx)
return [
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), pp.text(idx), pp.text(']')
])), pp.text(' += '), pp.text(core.pp_var(v, context))]
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules
def _get_jvp(primals: List[Any], tangents: List[Any]):
ref_primal, *idx = primals
assert isinstance(ref_primal.aval, ShapedArrayRef)
ref_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, ShapedArrayRef)
return ref_get(ref_primal, idx), ref_get(ref_tangent, idx) # type: ignore[arg-type]
ad.primitive_jvps[get_p] = _get_jvp
def _swap_jvp(primals: List[Any], tangents: List[Any]):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, ShapedArrayRef)
ref_tangent, x_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, ShapedArrayRef)
x_tangent = ad_util.instantiate(x_tangent)
return (ref_swap(ref_primal, idx, x_primal), # type: ignore[arg-type]
ref_swap(ref_tangent, idx, x_tangent)) # type: ignore[arg-type]
ad.primitive_jvps[swap_p] = _swap_jvp
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any]):
ref_primal, x_primal, *idx = primals
ref_tangent, x_tangent, *_ = tangents
x_tangent = ad_util.instantiate(x_tangent)
addupdate_p.bind(ref_primal, x_primal, *idx)
addupdate_p.bind(ref_tangent, x_tangent, *idx)
return [], []
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule
## get/swap/addupdate transpose rules
def _get_transpose(g, ref, *idx):
# get transpose is addupdate
if type(g) is not ad_util.Zero:
ref_addupdate(ref, idx, g)
return [None] + [None] * len(idx)
ad.primitive_transposes[get_p] = _get_transpose
def _swap_transpose(g, ref, x, *idx):
# swap transpose is swap
x_bar = ref_swap(ref, idx, ad_util.instantiate(g))
return [None, x_bar] + [None] * len(idx)
ad.primitive_transposes[swap_p] = _swap_transpose

107
jax/_src/state/types.py Normal file
View File

@ -0,0 +1,107 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state types."""
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from jax import api_util
from jax import core
from jax import linear_util as lu
from jax import tree_util
from jax._src import ad_util
from jax._src import device_array
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src.lib import xla_bridge, xla_client
from jax._src.util import safe_map, safe_zip, split_list
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import numpy as np
xc = xla_client
xb = xla_bridge
## JAX utilities
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Array = Any
class _StateEffect:
def __repr__(self):
return "State"
__str__ = __repr__
StateEffect = _StateEffect()
# ## `Ref`s
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
# A `ShapedArrayRef` is a abstract value for mutable containers of array types
class ShapedArrayRef(core.AbstractValue):
__slots__ = ["shape", "dtype"]
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
def join(self, other):
assert core.symbolic_equal_shape(self.shape, other.shape)
assert self.dtype == other.dtype
return self
@core.aval_method
@staticmethod
def get(tracer, idx=()):
from jax._src.state.primitives import ref_get
return ref_get(tracer, idx)
@core.aval_method
@staticmethod
def set(tracer, value, idx=()):
from jax._src.state.primitives import ref_set
return ref_set(tracer, idx, value)
def _getitem(self, tracer, idx) -> Array:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_get
return ref_get(tracer, idx)
def _setitem(self, tracer, idx, value) -> None:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_set
return ref_set(tracer, idx, value)
def __repr__(self) -> str:
a = core.ShapedArray(self.shape, self.dtype)
return f'Ref{{{a.str_short()}}}'
def at_least_vspace(self):
return self
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape)
def __hash__(self):
return hash((self.shape, self.dtype))
core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval

View File

@ -871,6 +871,15 @@ jax_test(
],
)
jax_test(
name = "state_test",
srcs = ["state_test.py"],
enable_configs = [
"gpu",
"cpu",
],
)
jax_test(
name = "clear_backends_test",
srcs = ["clear_backends_test.py"],

View File

@ -29,13 +29,11 @@ import jax
from jax import core
from jax.errors import UnexpectedTracerError
from jax import lax
from jax import linear_util as lu
from jax import random
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
from jax.interpreters import partial_eval as pe
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
@ -2563,320 +2561,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
class ForLoopTest(jtu.JaxTestCase):
def test_cant_eval_get_primitive(self):
with self.assertRaises(ValueError):
for_loop.get_p.bind(jnp.ones(5))
def test_cant_eval_swap_primitive(self):
with self.assertRaises(ValueError):
for_loop.swap_p.bind(jnp.ones(5), jnp.zeros(5))
def test_cant_eval_addupdate_primitive(self):
with self.assertRaises(ValueError):
for_loop.addupdate_p.bind(jnp.ones(5), jnp.zeros(5))
def test_get_abstract_eval(self):
ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32)
out_aval, effect = for_loop.get_p.abstract_eval(ref_aval, 0)
self.assertSetEqual(effect, {for_loop.State})
self.assertTupleEqual(out_aval.shape, (2, 3))
self.assertEqual(out_aval.dtype, jnp.float32)
def test_get_abstract_aval_must_take_in_refs(self):
with self.assertRaises(ValueError):
for_loop.get_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32))
def test_swap_abstract_eval(self):
ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32)
val_aval = core.ShapedArray((2, 3), jnp.float32)
out_aval, effect = for_loop.swap_p.abstract_eval(ref_aval, val_aval, 0)
self.assertSetEqual(effect, {for_loop.State})
self.assertTupleEqual(out_aval.shape, (2, 3))
self.assertEqual(out_aval.dtype, jnp.float32)
def test_swap_abstract_eval_must_take_in_refs(self):
with self.assertRaises(ValueError):
for_loop.swap_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3), jnp.float32))
def test_swap_checks_for_correct_shapes(self):
with self.assertRaises(ValueError):
for_loop.swap_p.abstract_eval(
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32))
with self.assertRaises(ValueError):
for_loop.swap_p.abstract_eval(
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3, 4), jnp.float32))
for_loop.swap_p.abstract_eval(
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32), 1)
def test_addupdate_abstract_eval(self):
ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32)
val_aval = core.ShapedArray((2, 3), jnp.float32)
out_avals, effect = for_loop.addupdate_p.abstract_eval(ref_aval, val_aval,
0)
self.assertSetEqual(effect, {for_loop.State})
self.assertListEqual(out_avals, [])
def test_addupdate_abstract_eval_must_take_in_refs(self):
with self.assertRaises(ValueError):
for_loop.addupdate_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3), jnp.float32))
def test_addupdate_checks_for_correct_shapes(self):
with self.assertRaises(ValueError):
for_loop.addupdate_p.abstract_eval(
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32))
with self.assertRaises(ValueError):
for_loop.addupdate_p.abstract_eval(
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3, 4), jnp.float32))
for_loop.addupdate_p.abstract_eval(
for_loop.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32), 1)
def test_can_represent_get_and_swap_in_jaxprs(self):
def body(x):
x[()] = jnp.int32(1)
x[()] = jnp.int32(2)
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.swap_p)
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.swap_p)
self.assertEqual(jaxpr.eqns[2].primitive, for_loop.get_p)
def test_can_represent_addupdate_in_jaxprs(self):
def body(x):
for_loop.ref_addupdate(x, (), jnp.int32(1))
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.addupdate_p)
def test_get_custom_pretty_printing_rule(self):
def body(x_ref):
x = x_ref[()]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
def test_set_custom_pretty_printing_rule(self):
def body(x_ref):
x_ref[()] = jnp.int32(2)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
def test_swap_custom_pretty_printing_rule(self):
def body(x_ref):
x = for_loop.ref_swap(x_ref, (), jnp.int32(2))
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
def test_addupdate_custom_pretty_printing_rule(self):
def body(x_ref):
for_loop.ref_addupdate(x_ref, (), jnp.int32(2))
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)])
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
def test_get_jvp(self):
def f(r):
x = r[()]
return jnp.cos(x)
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.get_p)
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.get_p)
def test_swap_jvp(self):
def f(a):
x = a[()]
a[()] = jnp.sin(x)
return a[()]
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.get_p)
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.get_p)
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p)
self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p)
self.assertEqual(jaxpr.eqns[5].primitive, for_loop.swap_p)
self.assertEqual(jaxpr.eqns[6].primitive, for_loop.swap_p)
def test_addupdate_jvp(self):
def f(a):
for_loop.ref_addupdate(a, (), 1.)
return a[()]
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, for_loop.addupdate_p)
self.assertEqual(jaxpr.eqns[1].primitive, for_loop.addupdate_p)
self.assertEqual(jaxpr.eqns[2].primitive, for_loop.get_p)
self.assertEqual(jaxpr.eqns[3].primitive, for_loop.get_p)
def test_discharge_get(self):
def f(a_ref):
a = for_loop.ref_get(a_ref, ())
return [a + 1]
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 2)
self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p)
# Should be able to evaluate this jaxpr
self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (),
jnp.float32(1.)), [2., 1.])
def test_discharge_get_with_slice(self):
def f(a_ref):
a = for_loop.ref_get(a_ref, (0, 1))
return [a + 1]
in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 2)
self.assertIn(lax.dynamic_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
# Should be able to evaluate this jaxpr
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((outval == inval[0, 1] + 1).all())
self.assertTrue((refval == inval).all())
def test_discharge_set(self):
def f(a_ref, b):
for_loop.ref_set(a_ref, (), b + 1)
return []
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that ignores the first
# value and returns second value plus 1.
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 2)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
jnp.float32(1.))[0], 2.)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
jnp.float32(1.))[0], 2.)
def test_discharge_set_with_slice(self):
def f(a_ref):
for_loop.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertIn(lax.dynamic_update_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
self.assertIn(lax.dynamic_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
# Should be able to evaluate this jaxpr
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((refval == inval.at[0, 1].set(1.)).all())
def test_discharge_addupdate(self):
def f(a_ref, b):
for_loop.ref_addupdate(a_ref, (), b + 1)
return []
in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that adds the first value,
# second value, and 1.
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 2)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
jnp.float32(1.))[0], 2.)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
jnp.float32(1.))[0], 4.)
def test_discharge_addupdate_with_slice(self):
def f(a_ref):
for_loop.ref_addupdate(a_ref, (0, 1),
jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertIn(lax.dynamic_update_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
self.assertIn(lax.add_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
self.assertIn(lax.dynamic_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((refval == inval.at[0, 1].add(1.)).all())
def test_discharge_jaxpr_with_multiple_outputs(self):
def f(a_ref):
a = for_loop.ref_get(a_ref, ())
b = a + 1
return [a, b]
in_avals = [for_loop.ShapedArrayRef((4,), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 3)
inval = jnp.arange(4., dtype=jnp.float32)
a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((a == inval).all())
self.assertTrue((b == inval + 1).all())
self.assertTrue((refval == inval).all())
def test_for_loop_impl_trivial(self):
out = for_loop.for_loop(5, lambda i, _: None, None)
self.assertEqual(out, None)

348
tests/state_test.py Normal file
View File

@ -0,0 +1,348 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax._src import state
config.parse_flags_with_absl()
class StatePrimitivesTest(jtu.JaxTestCase):
def test_cant_eval_get_primitive(self):
with self.assertRaises(ValueError):
state.get_p.bind(jnp.ones(5))
def test_cant_eval_swap_primitive(self):
with self.assertRaises(ValueError):
state.swap_p.bind(jnp.ones(5), jnp.zeros(5))
def test_cant_eval_addupdate_primitive(self):
with self.assertRaises(ValueError):
state.addupdate_p.bind(jnp.ones(5), jnp.zeros(5))
def test_get_abstract_eval(self):
ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32)
out_aval, effect = state.get_p.abstract_eval(ref_aval, 0)
self.assertSetEqual(effect, {state.StateEffect})
self.assertTupleEqual(out_aval.shape, (2, 3))
self.assertEqual(out_aval.dtype, jnp.float32)
def test_get_abstract_aval_must_take_in_refs(self):
with self.assertRaises(ValueError):
state.get_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32))
def test_swap_abstract_eval(self):
ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32)
val_aval = core.ShapedArray((2, 3), jnp.float32)
out_aval, effect = state.swap_p.abstract_eval(ref_aval, val_aval, 0)
self.assertSetEqual(effect, {state.StateEffect})
self.assertTupleEqual(out_aval.shape, (2, 3))
self.assertEqual(out_aval.dtype, jnp.float32)
def test_swap_abstract_eval_must_take_in_refs(self):
with self.assertRaises(ValueError):
state.swap_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3), jnp.float32))
def test_swap_checks_for_correct_shapes(self):
with self.assertRaises(ValueError):
state.swap_p.abstract_eval(
state.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32))
with self.assertRaises(ValueError):
state.swap_p.abstract_eval(
state.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3, 4), jnp.float32))
state.swap_p.abstract_eval(
state.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32), 1)
def test_addupdate_abstract_eval(self):
ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32)
val_aval = core.ShapedArray((2, 3), jnp.float32)
out_avals, effect = state.addupdate_p.abstract_eval(ref_aval, val_aval,
0)
self.assertSetEqual(effect, {state.StateEffect})
self.assertListEqual(out_avals, [])
def test_addupdate_abstract_eval_must_take_in_refs(self):
with self.assertRaises(ValueError):
state.addupdate_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3), jnp.float32))
def test_addupdate_checks_for_correct_shapes(self):
with self.assertRaises(ValueError):
state.addupdate_p.abstract_eval(
state.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32))
with self.assertRaises(ValueError):
state.addupdate_p.abstract_eval(
state.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((1, 2, 3, 4), jnp.float32))
state.addupdate_p.abstract_eval(
state.ShapedArrayRef((1, 2, 3), jnp.float32),
core.ShapedArray((2, 3), jnp.float32), 1)
def test_can_represent_get_and_swap_in_jaxprs(self):
def body(x):
x[()] = jnp.int32(1)
x[()] = jnp.int32(2)
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p)
self.assertEqual(jaxpr.eqns[1].primitive, state.swap_p)
self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)
def test_can_represent_addupdate_in_jaxprs(self):
def body(x):
state.ref_addupdate(x, (), jnp.int32(1))
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
def test_get_custom_pretty_printing_rule(self):
def body(x_ref):
x = x_ref[()]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
def test_set_custom_pretty_printing_rule(self):
def body(x_ref):
x_ref[()] = jnp.int32(2)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
def test_swap_custom_pretty_printing_rule(self):
def body(x_ref):
x = state.ref_swap(x_ref, (), jnp.int32(2))
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
def test_addupdate_custom_pretty_printing_rule(self):
def body(x_ref):
state.ref_addupdate(x_ref, (), jnp.int32(2))
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
def test_get_jvp(self):
def f(r):
x = r[()]
return jnp.cos(x)
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
state.ShapedArrayRef((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
def test_swap_jvp(self):
def f(a):
x = a[()]
a[()] = jnp.sin(x)
return a[()]
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
state.ShapedArrayRef((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p)
self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p)
self.assertEqual(jaxpr.eqns[5].primitive, state.swap_p)
self.assertEqual(jaxpr.eqns[6].primitive, state.swap_p)
def test_addupdate_jvp(self):
def f(a):
state.ref_addupdate(a, (), 1.)
return a[()]
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
state.ShapedArrayRef((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p)
self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)
self.assertEqual(jaxpr.eqns[3].primitive, state.get_p)
class StateDischargeTest(jtu.JaxTestCase):
def test_discharge_get(self):
def f(a_ref):
a = state.ref_get(a_ref, ())
return [a + 1]
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 2)
self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p)
# Should be able to evaluate this jaxpr
self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (),
jnp.float32(1.)), [2., 1.])
def test_discharge_get_with_slice(self):
def f(a_ref):
a = state.ref_get(a_ref, (0, 1))
return [a + 1]
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 2)
self.assertIn(lax.dynamic_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
# Should be able to evaluate this jaxpr
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((outval == inval[0, 1] + 1).all())
self.assertTrue((refval == inval).all())
def test_discharge_set(self):
def f(a_ref, b):
state.ref_set(a_ref, (), b + 1)
return []
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that ignores the first
# value and returns second value plus 1.
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 2)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
jnp.float32(1.))[0], 2.)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
jnp.float32(1.))[0], 2.)
def test_discharge_set_with_slice(self):
def f(a_ref):
state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertIn(lax.dynamic_update_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
self.assertIn(lax.dynamic_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
# Should be able to evaluate this jaxpr
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((refval == inval.at[0, 1].set(1.)).all())
def test_discharge_addupdate(self):
def f(a_ref, b):
state.ref_addupdate(a_ref, (), b + 1)
return []
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that adds the first value,
# second value, and 1.
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 2)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
jnp.float32(1.))[0], 2.)
self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
jnp.float32(1.))[0], 4.)
def test_discharge_addupdate_with_slice(self):
def f(a_ref):
state.ref_addupdate(a_ref, (0, 1),
jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertIn(lax.dynamic_update_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
self.assertIn(lax.add_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
self.assertIn(lax.dynamic_slice_p,
set(eqn.primitive for eqn in discharged_jaxpr.eqns))
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((refval == inval.at[0, 1].add(1.)).all())
def test_discharge_jaxpr_with_multiple_outputs(self):
def f(a_ref):
a = state.ref_get(a_ref, ())
b = a + 1
return [a, b]
in_avals = [state.ShapedArrayRef((4,), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
self.assertLen(discharged_jaxpr.outvars, 3)
inval = jnp.arange(4., dtype=jnp.float32)
a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
self.assertTrue((a == inval).all())
self.assertTrue((b == inval + 1).all())
self.assertTrue((refval == inval).all())
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())