From 8b7daa8095f90973e2557c394c5ff3c8006163bb Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 1 Aug 2022 15:22:15 -0700 Subject: [PATCH] Refactor state out of `for_loop` --- jax/_src/lax/control_flow/for_loop.py | 353 +------------------------- jax/_src/state/__init__.py | 19 ++ jax/_src/state/discharge.py | 121 +++++++++ jax/_src/state/primitives.py | 254 ++++++++++++++++++ jax/_src/state/types.py | 107 ++++++++ tests/lax_control_flow_test.py | 316 ----------------------- tests/state_test.py | 348 +++++++++++++++++++++++++ 7 files changed, 857 insertions(+), 661 deletions(-) create mode 100644 jax/_src/state/__init__.py create mode 100644 jax/_src/state/discharge.py create mode 100644 jax/_src/state/primitives.py create mode 100644 jax/_src/state/types.py create mode 100644 tests/state_test.py diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index a20efd154..93d49ad18 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -15,7 +15,7 @@ from functools import partial import operator -from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar +from typing import Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar from jax import core from jax import lax @@ -29,8 +29,8 @@ from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten, treedef_tuple, tree_map, tree_leaves, PyTreeDef) from jax._src import ad_util from jax._src import dtypes -from jax._src import pretty_printer as pp from jax._src import source_info_util +from jax._src import state from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip, split_list) import jax.numpy as jnp @@ -49,350 +49,13 @@ T = TypeVar('T') class Ref(Generic[T]): pass Array = Any -## State effect +StateEffect = state.StateEffect +ShapedArrayRef = state.ShapedArrayRef +ref_set = state.ref_set +ref_get = state.ref_get +ref_addupdate = state.ref_addupdate +discharge_state = state.discharge_state -class StateEffect: pass -State = StateEffect() - -## get/swap/addupdate implementations - -# `get` reads a value from a `Ref` type, a.k.a.: -# a = get_p.bind(x) -# or we can read using indices: -# a = get_p.bind(x, 0, 1) -# Staging out `a = get_p.bind(x)` where the aval of `x` is -# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like -# a:f32[3] <- x[] -get_p = core.Primitive("get") - -def _get_impl(ref: Ref, *idx: int): - del ref, idx - raise ValueError("Can't evaluate `get` outside a stateful context.") -get_p.def_impl(_get_impl) - -def ref_get(ref: Ref, idx: Tuple[int]) -> Array: - """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" - idx = map(jnp.int32, idx) - return get_p.bind(ref, *idx) - -# `swap` mutates a `Ref`, setting its value and returns its previous value. -# b = swap_p.bind(x, a) -# It generalizes the setting operation for a `Ref` as we can ignore the return -# value: -# _ = swap_p.bind(x, a) -# `swap_p` also takes in index arguments following the value, i.e.: -# _ = swap_p.bind(x, a, 0, 1) -# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is -# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is -# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like -# b:f32[3], x:Ref{f32[3]} <- x, a -# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is -# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is -# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j` -# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like -# x:Ref{f32[3]}[i, j] <- a -swap_p = core.Primitive("swap") - -def _swap_impl(ref: Ref, value: Array, *idx: int): - del ref, idx, value - raise ValueError("Can't evaluate `swap` outside a stateful context.") -swap_p.def_impl(_swap_impl) - -def ref_swap(ref: Ref, idx: Tuple[int], value: Array) -> Array: - """Sets a `Ref`'s value and returns the original value.""" - idx = map(jnp.int32, idx) - return swap_p.bind(ref, value, *idx) - -def ref_set(ref: Ref, idx: Tuple[int], value: Array) -> None: - """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" - ref_swap(ref, idx, value) - - -# `addupdate_p` mutates a `Ref`, adding a value to its existing value. -# Semantically, -# ``` -# addupdate ref a *idx -# ``` -# is equivalent to -# ``` -# b = get ref *idx -# c = add b x -# _ = swap ref c *idx -# ``` -addupdate_p = core.Primitive('addupdate') -addupdate_p.multiple_results = True - -def _addupdate_impl(ref: Ref, value: Array, *idx: int): - del ref, idx, value - raise ValueError("Can't evaluate `addupdate` outside a stateful context.") -addupdate_p.def_impl(_addupdate_impl) - -def ref_addupdate(ref: Ref, idx: Tuple[int], x: Array) -> None: - """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" - return addupdate_p.bind(ref, x, *idx) - -## get/set/addupdate abstract evaluation rules - -# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. -# A `ShapedArrayRef` is a abstract value for mutable containers of array types -class ShapedArrayRef(core.AbstractValue): - __slots__ = ["shape", "dtype"] - - def __init__(self, shape, dtype): - self.shape = shape - self.dtype = dtype - - def join(self, other): - assert core.symbolic_equal_shape(self.shape, other.shape) - assert self.dtype == other.dtype - return self - - def _getitem(self, tracer, idx) -> Array: - if not isinstance(idx, tuple): - idx = idx, - return ref_get(tracer, idx) - - def _setitem(self, tracer, idx, val) -> None: - if not isinstance(idx, tuple): - idx = idx, - return ref_set(tracer, idx, val) - - def __repr__(self) -> str: - a = core.ShapedArray(self.shape, self.dtype) - return f'Ref{{{a.str_short()}}}' - - def at_least_vspace(self): - return self - -core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval - -def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx: int): - if not isinstance(ref_aval, ShapedArrayRef): - raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") - return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State} -get_p.def_effectful_abstract_eval(_get_abstract_eval) - - -def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue, - *idx: int): - if not isinstance(ref_aval, ShapedArrayRef): - raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") - val_aval = core.raise_to_shaped(val_aval) - assert isinstance(val_aval, core.ShapedArray) - expected_output_shape = ref_aval.shape[len(idx):] - if expected_output_shape != val_aval.shape: - raise ValueError("Invalid shape for `swap`. " - f"Ref shape: {ref_aval.shape}. " - f"Value shape: {val_aval.shape}. " - f"Indices: {idx}. ") - if ref_aval.dtype != val_aval.dtype: - raise ValueError("Invalid dtype for `swap`. " - f"Ref dtype: {ref_aval.dtype}. " - f"Value shape: {val_aval.dtype}. ") - return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State} -swap_p.def_effectful_abstract_eval(_swap_abstract_eval) - - -def _addupdate_abstract_eval(ref_aval: ShapedArrayRef, - val_aval: core.AbstractValue, - *idx: int): - if not isinstance(ref_aval, ShapedArrayRef): - raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") - val_aval = core.raise_to_shaped(val_aval) - assert isinstance(val_aval, core.ShapedArray) - expected_output_shape = ref_aval.shape[len(idx):] - if expected_output_shape != val_aval.shape: - raise ValueError("Invalid shape for `swap`. " - f"Ref shape: {ref_aval.shape}. " - f"Value shape: {val_aval.shape}. " - f"Indices: {idx}. ") - return [], {State} -addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) - -## Pretty printing for `get` and `swap` in jaxprs - -pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL, - foreground=pp.Color.GREEN) - -def _get_pp_rule(eqn, context, settings): - # Pretty prints `a = get x i` as `a <- x[i]` - y, = eqn.outvars - x, *idx = eqn.invars - idx = ','.join(core.pp_var(i, context) for i in idx) - lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) - return [lhs, pp.text(' <- '), pp_ref(pp.concat([ - pp.text(core.pp_var(x, context)), pp.text('['), pp.text(idx), pp.text(']') - ]))] -core.pp_eqn_rules[get_p] = _get_pp_rule - -def _swap_pp_rule(eqn, context, settings): - y, = eqn.outvars - x, v, *idx = eqn.invars - idx = ','.join(core.pp_var(i, context) for i in idx) - if type(y) is core.DropVar: - # In the case of a set (ignored return value), - # pretty print `_ = swap x v i` as `x[i] <- v` - del y - return [ - pp_ref(pp.concat([ - pp.text(core.pp_var(x, context)), - pp.text('['), pp.text(idx), pp.text(']') - ])), pp.text(' <- '), pp.text(core.pp_var(v, context))] - else: - # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` - x_i = pp.concat([pp.text(core.pp_var(x, context)), - pp.text('['), pp.text(idx), pp.text(']')]) - y = core.pp_vars([y], context, print_shapes=settings.print_shapes) - return [y, pp.text(', '), x_i, pp.text(' <- '), - x_i, pp.text(', '), pp.text(core.pp_var(v, context))] -core.pp_eqn_rules[swap_p] = _swap_pp_rule - -def _addupdate_pp_rule(eqn, context, settings): - # pretty-print ` = addupdate x i v` as `x[i] += v` - () = eqn.outvars - x, v, *idx = eqn.invars - idx = ','.join(core.pp_var(i, context) for i in idx) - return [ - pp_ref(pp.concat([ - pp.text(core.pp_var(x, context)), - pp.text('['), pp.text(idx), pp.text(']') - ])), pp.text(' += '), pp.text(core.pp_var(v, context))] -core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule - -## get/swap/addupdate JVP rules - -def _get_jvp(primals: List[Any], tangents: List[Any]): - ref_primal, *idx = primals - assert isinstance(ref_primal.aval, ShapedArrayRef) - ref_tangent, *_ = tangents - assert isinstance(ref_tangent.aval, ShapedArrayRef) - return ref_get(ref_primal, idx), ref_get(ref_tangent, idx) # type: ignore[arg-type] -ad.primitive_jvps[get_p] = _get_jvp - -def _swap_jvp(primals: List[Any], tangents: List[Any]): - ref_primal, x_primal, *idx = primals - assert isinstance(ref_primal.aval, ShapedArrayRef) - ref_tangent, x_tangent, *_ = tangents - assert isinstance(ref_tangent.aval, ShapedArrayRef) - x_tangent = ad_util.instantiate(x_tangent) - return (ref_swap(ref_primal, idx, x_primal), # type: ignore[arg-type] - ref_swap(ref_tangent, idx, x_tangent)) # type: ignore[arg-type] -ad.primitive_jvps[swap_p] = _swap_jvp - -def addupdate_jvp_rule(primals: List[Any], tangents: List[Any]): - ref_primal, x_primal, *idx = primals - ref_tangent, x_tangent, *_ = tangents - x_tangent = ad_util.instantiate(x_tangent) - addupdate_p.bind(ref_primal, x_primal, *idx) - addupdate_p.bind(ref_tangent, x_tangent, *idx) - return [], [] -ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule - -## get/swap/addupdate transpose rules - -def _get_transpose(g, ref, *idx): - # get transpose is addupdate - if type(g) is not ad_util.Zero: - ref_addupdate(ref, idx, g) - return [None] + [None] * len(idx) -ad.primitive_transposes[get_p] = _get_transpose - -def _swap_transpose(g, ref, x, *idx): - # swap transpose is swap - x_bar = ref_swap(ref, idx, ad_util.instantiate(g)) - return [None, x_bar] + [None] * len(idx) -ad.primitive_transposes[swap_p] = _swap_transpose - - -## Discharging state - -# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values -# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr -# into a "pure" jaxpr that takes in and outputs values and no longer has the -# `State` effect. - -def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]: - """Converts a jaxpr that takes in `Ref`s into one that doesn't.""" - in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype) - if type(v.aval) is ShapedArrayRef - else v.aval for v in jaxpr.invars] - eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, consts)) - new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals) - return new_jaxpr, new_consts - -def _dynamic_index(x, idx): - if not idx: return x - ndim = len(x.shape) - starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx)) - sizes = (1,) * len(idx) + x.shape[len(idx):] - out = lax.dynamic_slice(x, starts, sizes) - return out.reshape(x.shape[len(idx):]) - -def _dynamic_update_index(x, idx, val): - if not idx: return val - ndim = len(x.shape) - starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx)) - update = val.reshape((1,) * len(idx) + x.shape[len(idx):]) - return lax.dynamic_update_slice(x, update, starts) - -def _eval_jaxpr_discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], - *args: Any): - env: Dict[core.Var, Any] = {} - - def read(v: core.Atom) -> Any: - if type(v) is core.Literal: - return v.val - assert isinstance(v, core.Var) - return env[v] - - def write(v: core.Var, val: Any) -> None: - env[v] = val - - map(write, jaxpr.constvars, consts) - # Here some args may correspond to `Ref` avals but they'll be treated like - # regular values in this interpreter. - map(write, jaxpr.invars, args) - - for eqn in jaxpr.eqns: - in_vals = map(read, eqn.invars) - if eqn.primitive is get_p: - # `y <- x[i]` becomes `y = ds x i` - x, *idx = in_vals - write(eqn.outvars[0], _dynamic_index(x, idx)) - elif eqn.primitive is swap_p: - # `z, x[i] <- x[i], val` becomes: - # z = ds x i - # x = dus x i val - x, val, *idx = in_vals - write(eqn.outvars[0], _dynamic_index(x, idx)) - assert isinstance(eqn.invars[0], core.Var) - write(eqn.invars[0], _dynamic_update_index(x, idx, val)) - elif eqn.primitive is addupdate_p: - # `x[i] += val` becomes: - # y = ds x i - # z = y + val - # x = dus x i z - x, val, *idx = in_vals - ans = _dynamic_update_index(x, idx, val + _dynamic_index(x, idx)) - assert isinstance(eqn.invars[0], core.Var) - write(eqn.invars[0], ans) - else: - # Default primitive rule, similar to `core.eval_jaxpr`. Note that here - # we assume any higher-order primitives inside of the jaxpr are *not* - # stateful. - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params) - if eqn.primitive.multiple_results: - map(write, eqn.outvars, ans) - else: - write(eqn.outvars[0], ans) - # By convention, we return the outputs of the jaxpr first and then the final - # values of the `Ref`s. Callers to this function should be able to split - # them up by looking at `len(jaxpr.outvars)`. - out_vals = map(read, jaxpr.outvars) - ref_vals = map( - read, [v for v in jaxpr.invars if type(v.aval) is ShapedArrayRef]) - return out_vals + ref_vals ## `for_loop` implementation diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py new file mode 100644 index 000000000..0447a5368 --- /dev/null +++ b/jax/_src/state/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for state.""" +from jax._src.state.types import ShapedArrayRef, StateEffect +from jax._src.state.primitives import (ref_get, ref_set, ref_swap, + ref_addupdate, get_p, swap_p, + addupdate_p) +from jax._src.state.discharge import discharge_state diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py new file mode 100644 index 000000000..13365a71e --- /dev/null +++ b/jax/_src/state/discharge.py @@ -0,0 +1,121 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for discharging state primitives.""" +from functools import partial + +from typing import Any, Dict, List, Sequence, Tuple + +from jax import core +from jax import lax +from jax import linear_util as lu +from jax.interpreters import partial_eval as pe +from jax._src.util import safe_map, safe_zip + +from jax._src.state.types import ShapedArrayRef +from jax._src.state.primitives import get_p, swap_p, addupdate_p + +## JAX utilities + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +## Discharging state + +# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values +# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr +# into a "pure" jaxpr that takes in and outputs values and no longer has the +# `StateEffect` effect. + +def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]: + """Converts a jaxpr that takes in `Ref`s into one that doesn't.""" + in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype) + if type(v.aval) is ShapedArrayRef + else v.aval for v in jaxpr.invars] + eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, consts)) + new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals) + return new_jaxpr, new_consts + +def _dynamic_index(x, idx): + if not idx: return x + ndim = len(x.shape) + starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx)) + sizes = (1,) * len(idx) + x.shape[len(idx):] + out = lax.dynamic_slice(x, starts, sizes) + return out.reshape(x.shape[len(idx):]) + +def _dynamic_update_index(x, idx, val): + if not idx: return val + ndim = len(x.shape) + starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx)) + update = val.reshape((1,) * len(idx) + x.shape[len(idx):]) + return lax.dynamic_update_slice(x, update, starts) + +def _eval_jaxpr_discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], + *args: Any): + env: Dict[core.Var, Any] = {} + + def read(v: core.Atom) -> Any: + if type(v) is core.Literal: + return v.val + assert isinstance(v, core.Var) + return env[v] + + def write(v: core.Var, val: Any) -> None: + env[v] = val + + map(write, jaxpr.constvars, consts) + # Here some args may correspond to `Ref` avals but they'll be treated like + # regular values in this interpreter. + map(write, jaxpr.invars, args) + + for eqn in jaxpr.eqns: + in_vals = map(read, eqn.invars) + if eqn.primitive is get_p: + # `y <- x[i]` becomes `y = ds x i` + x, *idx = in_vals + write(eqn.outvars[0], _dynamic_index(x, idx)) + elif eqn.primitive is swap_p: + # `z, x[i] <- x[i], val` becomes: + # z = ds x i + # x = dus x i val + x, val, *idx = in_vals + write(eqn.outvars[0], _dynamic_index(x, idx)) + assert isinstance(eqn.invars[0], core.Var) + write(eqn.invars[0], _dynamic_update_index(x, idx, val)) + elif eqn.primitive is addupdate_p: + # `x[i] += val` becomes: + # y = ds x i + # z = y + val + # x = dus x i z + x, val, *idx = in_vals + ans = _dynamic_update_index(x, idx, val + _dynamic_index(x, idx)) + assert isinstance(eqn.invars[0], core.Var) + write(eqn.invars[0], ans) + else: + # Default primitive rule, similar to `core.eval_jaxpr`. Note that here + # we assume any higher-order primitives inside of the jaxpr are *not* + # stateful. + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params) + if eqn.primitive.multiple_results: + map(write, eqn.outvars, ans) + else: + write(eqn.outvars[0], ans) + # By convention, we return the outputs of the jaxpr first and then the final + # values of the `Ref`s. Callers to this function should be able to split + # them up by looking at `len(jaxpr.outvars)`. + out_vals = map(read, jaxpr.outvars) + ref_vals = map( + read, [v for v in jaxpr.invars if type(v.aval) is ShapedArrayRef]) + return out_vals + ref_vals diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py new file mode 100644 index 000000000..211d8a634 --- /dev/null +++ b/jax/_src/state/primitives.py @@ -0,0 +1,254 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for state primitives.""" +from functools import partial + +from typing import Any, Generic, List, Tuple, TypeVar + +from jax import core +from jax._src import ad_util +from jax._src import pretty_printer as pp +from jax._src.util import safe_map, safe_zip +from jax.interpreters import ad +import jax.numpy as jnp + +from jax._src.state.types import ShapedArrayRef, StateEffect + +## General utilities + +Array = Any +T = TypeVar('T') +class Ref(Generic[T]): pass + +## JAX utilities + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +## get/swap/addupdate implementations + +# `get` reads a value from a `Ref` type, a.k.a.: +# a = get_p.bind(x) +# or we can read using indices: +# a = get_p.bind(x, 0, 1) +# Staging out `a = get_p.bind(x)` where the aval of `x` is +# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like +# a:f32[3] <- x[] +get_p = core.Primitive("get") + +def _get_impl(ref: Ref, *idx: int): + del ref, idx + raise ValueError("Cannot run stateful primitive.") +get_p.def_impl(_get_impl) + +def ref_get(ref: Ref, idx: Tuple[int]) -> Array: + """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" + idx = map(jnp.int32, idx) + return get_p.bind(ref, *idx) + +# `swap` mutates a `Ref`, setting its value and returns its previous value. +# b = swap_p.bind(x, a) +# It generalizes the setting operation for a `Ref` as we can ignore the return +# value: +# _ = swap_p.bind(x, a) +# `swap_p` also takes in index arguments following the value, i.e.: +# _ = swap_p.bind(x, a, 0, 1) +# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is +# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is +# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like +# b:f32[3], x:Ref{f32[3]} <- x, a +# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is +# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is +# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j` +# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like +# x:Ref{f32[3]}[i, j] <- a +swap_p = core.Primitive("swap") + +def _swap_impl(ref: Ref, value: Array, *idx: int): + del ref, value, idx + raise ValueError("Cannot run stateful primitive.") +swap_p.def_impl(_swap_impl) + +def ref_swap(ref: Ref, idx: Tuple[int], value: Array) -> Array: + """Sets a `Ref`'s value and returns the original value.""" + idx = map(jnp.int32, idx) + return swap_p.bind(ref, value, *idx) + +def ref_set(ref: Ref, idx: Tuple[int], value: Array) -> None: + """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" + ref_swap(ref, idx, value) + +# `addupdate_p` mutates a `Ref`, adding a value to its existing value. +# Semantically, +# ``` +# addupdate ref a *idx +# ``` +# is equivalent to +# ``` +# b = get ref *idx +# c = add b x +# _ = swap ref c *idx +# ``` +addupdate_p = core.Primitive('addupdate') +addupdate_p.multiple_results = True + +def _addupdate_impl(ref: Ref, value: Array, *idx: int): + del ref, idx, value + raise ValueError("Can't evaluate `addupdate` outside a stateful context.") +addupdate_p.def_impl(_addupdate_impl) + +def ref_addupdate(ref: Ref, idx: Tuple[int], x: Array) -> None: + """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" + return addupdate_p.bind(ref, x, *idx) + +## get/set/addupdate abstract evaluation rules + +def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx: int): + if not isinstance(ref_aval, ShapedArrayRef): + raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") + return (core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), + {StateEffect}) +get_p.def_effectful_abstract_eval(_get_abstract_eval) + + +def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue, + *idx: int): + if not isinstance(ref_aval, ShapedArrayRef): + raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") + val_aval = core.raise_to_shaped(val_aval) + assert isinstance(val_aval, core.ShapedArray) + expected_output_shape = ref_aval.shape[len(idx):] + if expected_output_shape != val_aval.shape: + raise ValueError("Invalid shape for `swap`. " + f"Ref shape: {ref_aval.shape}. " + f"Value shape: {val_aval.shape}. " + f"Indices: {idx}. ") + if ref_aval.dtype != val_aval.dtype: + raise ValueError("Invalid dtype for `swap`. " + f"Ref dtype: {ref_aval.dtype}. " + f"Value shape: {val_aval.dtype}. ") + return (core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), + {StateEffect}) +swap_p.def_effectful_abstract_eval(_swap_abstract_eval) + + +def _addupdate_abstract_eval(ref_aval: ShapedArrayRef, + val_aval: core.AbstractValue, + *idx: int): + if not isinstance(ref_aval, ShapedArrayRef): + raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") + val_aval = core.raise_to_shaped(val_aval) + assert isinstance(val_aval, core.ShapedArray) + expected_output_shape = ref_aval.shape[len(idx):] + if expected_output_shape != val_aval.shape: + raise ValueError("Invalid shape for `swap`. " + f"Ref shape: {ref_aval.shape}. " + f"Value shape: {val_aval.shape}. " + f"Indices: {idx}. ") + return [], {StateEffect} +addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) + +## Pretty printing for `get` and `swap` in jaxprs + +pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL, + foreground=pp.Color.GREEN) + +def _get_pp_rule(eqn, context, settings): + # Pretty prints `a = get x i` as `x[i] <- a` + y, = eqn.outvars + x, *idx = eqn.invars + idx = ','.join(core.pp_var(i, context) for i in idx) + lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) + return [lhs, pp.text(' <- '), pp_ref(pp.concat([ + pp.text(core.pp_var(x, context)), pp.text('['), pp.text(idx), pp.text(']') + ]))] +core.pp_eqn_rules[get_p] = _get_pp_rule + +def _swap_pp_rule(eqn, context, settings): + y, = eqn.outvars + x, v, *idx = eqn.invars + idx = ','.join(core.pp_var(i, context) for i in idx) + if type(y) is core.DropVar: + # In the case of a set (ignored return value), + # pretty print `_ = swap x v i` as `x[i] <- v` + del y + return [ + pp_ref(pp.concat([ + pp.text(core.pp_var(x, context)), + pp.text('['), pp.text(idx), pp.text(']') + ])), pp.text(' <- '), pp.text(core.pp_var(v, context))] + else: + # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` + x_i = pp.concat([pp.text(core.pp_var(x, context)), + pp.text('['), pp.text(idx), pp.text(']')]) + y = core.pp_vars([y], context, print_shapes=settings.print_shapes) + return [y, pp.text(', '), x_i, pp.text(' <- '), + x_i, pp.text(', '), pp.text(core.pp_var(v, context))] +core.pp_eqn_rules[swap_p] = _swap_pp_rule + +def _addupdate_pp_rule(eqn, context, settings): + # pretty-print ` = addupdate x i v` as `x[i] += v` + () = eqn.outvars + x, v, *idx = eqn.invars + idx = ','.join(core.pp_var(i, context) for i in idx) + return [ + pp_ref(pp.concat([ + pp.text(core.pp_var(x, context)), + pp.text('['), pp.text(idx), pp.text(']') + ])), pp.text(' += '), pp.text(core.pp_var(v, context))] +core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule + +## get/swap/addupdate JVP rules + +def _get_jvp(primals: List[Any], tangents: List[Any]): + ref_primal, *idx = primals + assert isinstance(ref_primal.aval, ShapedArrayRef) + ref_tangent, *_ = tangents + assert isinstance(ref_tangent.aval, ShapedArrayRef) + return ref_get(ref_primal, idx), ref_get(ref_tangent, idx) # type: ignore[arg-type] +ad.primitive_jvps[get_p] = _get_jvp + +def _swap_jvp(primals: List[Any], tangents: List[Any]): + ref_primal, x_primal, *idx = primals + assert isinstance(ref_primal.aval, ShapedArrayRef) + ref_tangent, x_tangent, *_ = tangents + assert isinstance(ref_tangent.aval, ShapedArrayRef) + x_tangent = ad_util.instantiate(x_tangent) + return (ref_swap(ref_primal, idx, x_primal), # type: ignore[arg-type] + ref_swap(ref_tangent, idx, x_tangent)) # type: ignore[arg-type] +ad.primitive_jvps[swap_p] = _swap_jvp + +def addupdate_jvp_rule(primals: List[Any], tangents: List[Any]): + ref_primal, x_primal, *idx = primals + ref_tangent, x_tangent, *_ = tangents + x_tangent = ad_util.instantiate(x_tangent) + addupdate_p.bind(ref_primal, x_primal, *idx) + addupdate_p.bind(ref_tangent, x_tangent, *idx) + return [], [] +ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule + +## get/swap/addupdate transpose rules + +def _get_transpose(g, ref, *idx): + # get transpose is addupdate + if type(g) is not ad_util.Zero: + ref_addupdate(ref, idx, g) + return [None] + [None] * len(idx) +ad.primitive_transposes[get_p] = _get_transpose + +def _swap_transpose(g, ref, x, *idx): + # swap transpose is swap + x_bar = ref_swap(ref, idx, ad_util.instantiate(g)) + return [None, x_bar] + [None] * len(idx) +ad.primitive_transposes[swap_p] = _swap_transpose diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py new file mode 100644 index 000000000..a22c35138 --- /dev/null +++ b/jax/_src/state/types.py @@ -0,0 +1,107 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for state types.""" +from functools import partial + +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from jax import api_util +from jax import core +from jax import linear_util as lu +from jax import tree_util +from jax._src import ad_util +from jax._src import device_array +from jax._src import dispatch +from jax._src import pretty_printer as pp +from jax._src.lib import xla_bridge, xla_client +from jax._src.util import safe_map, safe_zip, split_list +from jax.interpreters import ad +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import xla +import numpy as np + +xc = xla_client +xb = xla_bridge + +## JAX utilities + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +Array = Any + +class _StateEffect: + def __repr__(self): + return "State" + __str__ = __repr__ +StateEffect = _StateEffect() + +# ## `Ref`s + +# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. +# A `ShapedArrayRef` is a abstract value for mutable containers of array types +class ShapedArrayRef(core.AbstractValue): + __slots__ = ["shape", "dtype"] + + def __init__(self, shape, dtype): + self.shape = shape + self.dtype = dtype + + def join(self, other): + assert core.symbolic_equal_shape(self.shape, other.shape) + assert self.dtype == other.dtype + return self + + @core.aval_method + @staticmethod + def get(tracer, idx=()): + from jax._src.state.primitives import ref_get + return ref_get(tracer, idx) + + @core.aval_method + @staticmethod + def set(tracer, value, idx=()): + from jax._src.state.primitives import ref_set + return ref_set(tracer, idx, value) + + def _getitem(self, tracer, idx) -> Array: + if not isinstance(idx, tuple): + idx = idx, + from jax._src.state.primitives import ref_get + return ref_get(tracer, idx) + + def _setitem(self, tracer, idx, value) -> None: + if not isinstance(idx, tuple): + idx = idx, + from jax._src.state.primitives import ref_set + return ref_set(tracer, idx, value) + + def __repr__(self) -> str: + a = core.ShapedArray(self.shape, self.dtype) + return f'Ref{{{a.str_short()}}}' + + def at_least_vspace(self): + return self + + def __eq__(self, other): + return (type(self) is type(other) + and self.dtype == other.dtype and self.shape == other.shape) + + def __hash__(self): + return hash((self.shape, self.dtype)) + + +core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 5e5dc6a8c..79c7e2fc7 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -29,13 +29,11 @@ import jax from jax import core from jax.errors import UnexpectedTracerError from jax import lax -from jax import linear_util as lu from jax import random from jax._src import test_util as jtu from jax import tree_util from jax._src.util import unzip2 from jax.experimental import maps -from jax.interpreters import partial_eval as pe from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp @@ -2563,320 +2561,6 @@ class LaxControlFlowTest(jtu.JaxTestCase): class ForLoopTest(jtu.JaxTestCase): - def test_cant_eval_get_primitive(self): - with self.assertRaises(ValueError): - for_loop.get_p.bind(jnp.ones(5)) - - def test_cant_eval_swap_primitive(self): - with self.assertRaises(ValueError): - for_loop.swap_p.bind(jnp.ones(5), jnp.zeros(5)) - - def test_cant_eval_addupdate_primitive(self): - with self.assertRaises(ValueError): - for_loop.addupdate_p.bind(jnp.ones(5), jnp.zeros(5)) - - def test_get_abstract_eval(self): - ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32) - out_aval, effect = for_loop.get_p.abstract_eval(ref_aval, 0) - self.assertSetEqual(effect, {for_loop.State}) - self.assertTupleEqual(out_aval.shape, (2, 3)) - self.assertEqual(out_aval.dtype, jnp.float32) - - def test_get_abstract_aval_must_take_in_refs(self): - with self.assertRaises(ValueError): - for_loop.get_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32)) - - def test_swap_abstract_eval(self): - ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32) - val_aval = core.ShapedArray((2, 3), jnp.float32) - out_aval, effect = for_loop.swap_p.abstract_eval(ref_aval, val_aval, 0) - self.assertSetEqual(effect, {for_loop.State}) - self.assertTupleEqual(out_aval.shape, (2, 3)) - self.assertEqual(out_aval.dtype, jnp.float32) - - def test_swap_abstract_eval_must_take_in_refs(self): - with self.assertRaises(ValueError): - for_loop.swap_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32), - core.ShapedArray((1, 2, 3), jnp.float32)) - - def test_swap_checks_for_correct_shapes(self): - with self.assertRaises(ValueError): - for_loop.swap_p.abstract_eval( - for_loop.ShapedArrayRef((1, 2, 3), jnp.float32), - core.ShapedArray((2, 3), jnp.float32)) - with self.assertRaises(ValueError): - for_loop.swap_p.abstract_eval( - for_loop.ShapedArrayRef((1, 2, 3), jnp.float32), - core.ShapedArray((1, 2, 3, 4), jnp.float32)) - for_loop.swap_p.abstract_eval( - for_loop.ShapedArrayRef((1, 2, 3), jnp.float32), - core.ShapedArray((2, 3), jnp.float32), 1) - - def test_addupdate_abstract_eval(self): - ref_aval = for_loop.ShapedArrayRef((1, 2, 3), jnp.float32) - val_aval = core.ShapedArray((2, 3), jnp.float32) - out_avals, effect = for_loop.addupdate_p.abstract_eval(ref_aval, val_aval, - 0) - self.assertSetEqual(effect, {for_loop.State}) - self.assertListEqual(out_avals, []) - - def test_addupdate_abstract_eval_must_take_in_refs(self): - with self.assertRaises(ValueError): - for_loop.addupdate_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32), - core.ShapedArray((1, 2, 3), jnp.float32)) - - def test_addupdate_checks_for_correct_shapes(self): - with self.assertRaises(ValueError): - for_loop.addupdate_p.abstract_eval( - for_loop.ShapedArrayRef((1, 2, 3), jnp.float32), - core.ShapedArray((2, 3), jnp.float32)) - with self.assertRaises(ValueError): - for_loop.addupdate_p.abstract_eval( - for_loop.ShapedArrayRef((1, 2, 3), jnp.float32), - core.ShapedArray((1, 2, 3, 4), jnp.float32)) - for_loop.addupdate_p.abstract_eval( - for_loop.ShapedArrayRef((1, 2, 3), jnp.float32), - core.ShapedArray((2, 3), jnp.float32), 1) - - def test_can_represent_get_and_swap_in_jaxprs(self): - - def body(x): - x[()] = jnp.int32(1) - x[()] = jnp.int32(2) - return (x[()],) - jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) - self.assertLen(consts, 0) - self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) - self.assertEqual(jaxpr.eqns[0].primitive, for_loop.swap_p) - self.assertEqual(jaxpr.eqns[1].primitive, for_loop.swap_p) - self.assertEqual(jaxpr.eqns[2].primitive, for_loop.get_p) - - def test_can_represent_addupdate_in_jaxprs(self): - - def body(x): - for_loop.ref_addupdate(x, (), jnp.int32(1)) - return (x[()],) - jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) - self.assertLen(consts, 0) - self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) - self.assertEqual(jaxpr.eqns[0].primitive, for_loop.addupdate_p) - - def test_get_custom_pretty_printing_rule(self): - def body(x_ref): - x = x_ref[()] - return [x] - jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) - self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False)) - - def test_set_custom_pretty_printing_rule(self): - def body(x_ref): - x_ref[()] = jnp.int32(2) - return [] - jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) - self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) - - def test_swap_custom_pretty_printing_rule(self): - def body(x_ref): - x = for_loop.ref_swap(x_ref, (), jnp.int32(2)) - return [x] - jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) - self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) - - def test_addupdate_custom_pretty_printing_rule(self): - def body(x_ref): - for_loop.ref_addupdate(x_ref, (), jnp.int32(2)) - return [] - jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) - self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False)) - - def test_get_jvp(self): - - def f(r): - x = r[()] - return jnp.cos(x) - - def g(r, rdot): - return jax.jvp(f, (r,), (rdot,)) - - in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')), - for_loop.ShapedArrayRef((), jnp.dtype('float32'))] - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) - self.assertEqual(jaxpr.eqns[0].primitive, for_loop.get_p) - self.assertEqual(jaxpr.eqns[1].primitive, for_loop.get_p) - - def test_swap_jvp(self): - - def f(a): - x = a[()] - a[()] = jnp.sin(x) - return a[()] - - def g(r, rdot): - return jax.jvp(f, (r,), (rdot,)) - - in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')), - for_loop.ShapedArrayRef((), jnp.dtype('float32'))] - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) - self.assertEqual(jaxpr.eqns[0].primitive, for_loop.get_p) - self.assertEqual(jaxpr.eqns[1].primitive, for_loop.get_p) - self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p) - self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p) - self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p) - self.assertEqual(jaxpr.eqns[5].primitive, for_loop.swap_p) - self.assertEqual(jaxpr.eqns[6].primitive, for_loop.swap_p) - - def test_addupdate_jvp(self): - - def f(a): - for_loop.ref_addupdate(a, (), 1.) - return a[()] - - def g(r, rdot): - return jax.jvp(f, (r,), (rdot,)) - - in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')), - for_loop.ShapedArrayRef((), jnp.dtype('float32'))] - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) - self.assertEqual(jaxpr.eqns[0].primitive, for_loop.addupdate_p) - self.assertEqual(jaxpr.eqns[1].primitive, for_loop.addupdate_p) - self.assertEqual(jaxpr.eqns[2].primitive, for_loop.get_p) - self.assertEqual(jaxpr.eqns[3].primitive, for_loop.get_p) - - def test_discharge_get(self): - def f(a_ref): - a = for_loop.ref_get(a_ref, ()) - return [a + 1] - in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), - in_avals) - # Discharging should just turn this into a jaxpr that just adds 1. - discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts) - self.assertLen(discharged_jaxpr.invars, 1) - self.assertLen(discharged_jaxpr.outvars, 2) - self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p) - # Should be able to evaluate this jaxpr - self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (), - jnp.float32(1.)), [2., 1.]) - - def test_discharge_get_with_slice(self): - def f(a_ref): - a = for_loop.ref_get(a_ref, (0, 1)) - return [a + 1] - in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), - in_avals) - # Discharging should just turn this into a jaxpr that just adds 1. - discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts) - self.assertLen(discharged_jaxpr.invars, 1) - self.assertLen(discharged_jaxpr.outvars, 2) - self.assertIn(lax.dynamic_slice_p, - set(eqn.primitive for eqn in discharged_jaxpr.eqns)) - # Should be able to evaluate this jaxpr - inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) - outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval) - self.assertTrue((outval == inval[0, 1] + 1).all()) - self.assertTrue((refval == inval).all()) - - def test_discharge_set(self): - def f(a_ref, b): - for_loop.ref_set(a_ref, (), b + 1) - return [] - in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')), - core.ShapedArray((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), - in_avals) - # Discharging should just turn this into a jaxpr that ignores the first - # value and returns second value plus 1. - discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts) - self.assertLen(discharged_jaxpr.invars, 2) - self.assertLen(discharged_jaxpr.outvars, 1) - self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.), - jnp.float32(1.))[0], 2.) - self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.), - jnp.float32(1.))[0], 2.) - - def test_discharge_set_with_slice(self): - def f(a_ref): - for_loop.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32'))) - return [] - in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), - in_avals) - # Discharging should just turn this into a jaxpr that just adds 1. - discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts) - self.assertLen(discharged_jaxpr.invars, 1) - self.assertLen(discharged_jaxpr.outvars, 1) - self.assertIn(lax.dynamic_update_slice_p, - set(eqn.primitive for eqn in discharged_jaxpr.eqns)) - self.assertIn(lax.dynamic_slice_p, - set(eqn.primitive for eqn in discharged_jaxpr.eqns)) - # Should be able to evaluate this jaxpr - inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) - refval, = core.eval_jaxpr(discharged_jaxpr, (), inval) - self.assertTrue((refval == inval.at[0, 1].set(1.)).all()) - - def test_discharge_addupdate(self): - def f(a_ref, b): - for_loop.ref_addupdate(a_ref, (), b + 1) - return [] - in_avals = [for_loop.ShapedArrayRef((), jnp.dtype('float32')), - core.ShapedArray((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), - in_avals) - # Discharging should just turn this into a jaxpr that adds the first value, - # second value, and 1. - discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts) - self.assertLen(discharged_jaxpr.invars, 2) - self.assertLen(discharged_jaxpr.outvars, 1) - self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.), - jnp.float32(1.))[0], 2.) - self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.), - jnp.float32(1.))[0], 4.) - - def test_discharge_addupdate_with_slice(self): - def f(a_ref): - for_loop.ref_addupdate(a_ref, (0, 1), - jnp.ones(2, dtype=jnp.dtype('float32'))) - return [] - in_avals = [for_loop.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), - in_avals) - discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts) - self.assertLen(discharged_jaxpr.invars, 1) - self.assertLen(discharged_jaxpr.outvars, 1) - self.assertIn(lax.dynamic_update_slice_p, - set(eqn.primitive for eqn in discharged_jaxpr.eqns)) - self.assertIn(lax.add_p, - set(eqn.primitive for eqn in discharged_jaxpr.eqns)) - self.assertIn(lax.dynamic_slice_p, - set(eqn.primitive for eqn in discharged_jaxpr.eqns)) - inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) - refval, = core.eval_jaxpr(discharged_jaxpr, (), inval) - self.assertTrue((refval == inval.at[0, 1].add(1.)).all()) - - def test_discharge_jaxpr_with_multiple_outputs(self): - def f(a_ref): - a = for_loop.ref_get(a_ref, ()) - b = a + 1 - return [a, b] - in_avals = [for_loop.ShapedArrayRef((4,), jnp.dtype('float32'))] - stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), - in_avals) - discharged_jaxpr, _ = for_loop.discharge_state(stateful_jaxpr, consts) - self.assertLen(discharged_jaxpr.invars, 1) - self.assertLen(discharged_jaxpr.outvars, 3) - inval = jnp.arange(4., dtype=jnp.float32) - a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval) - self.assertTrue((a == inval).all()) - self.assertTrue((b == inval + 1).all()) - self.assertTrue((refval == inval).all()) - def test_for_loop_impl_trivial(self): out = for_loop.for_loop(5, lambda i, _: None, None) self.assertEqual(out, None) diff --git a/tests/state_test.py b/tests/state_test.py new file mode 100644 index 000000000..60afaf755 --- /dev/null +++ b/tests/state_test.py @@ -0,0 +1,348 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +import jax +from jax import core +from jax import lax +from jax import linear_util as lu +from jax.config import config +from jax.interpreters import partial_eval as pe +from jax._src import test_util as jtu +import jax.numpy as jnp + +from jax._src import state + +config.parse_flags_with_absl() + +class StatePrimitivesTest(jtu.JaxTestCase): + + def test_cant_eval_get_primitive(self): + with self.assertRaises(ValueError): + state.get_p.bind(jnp.ones(5)) + + def test_cant_eval_swap_primitive(self): + with self.assertRaises(ValueError): + state.swap_p.bind(jnp.ones(5), jnp.zeros(5)) + + def test_cant_eval_addupdate_primitive(self): + with self.assertRaises(ValueError): + state.addupdate_p.bind(jnp.ones(5), jnp.zeros(5)) + + def test_get_abstract_eval(self): + ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32) + out_aval, effect = state.get_p.abstract_eval(ref_aval, 0) + self.assertSetEqual(effect, {state.StateEffect}) + self.assertTupleEqual(out_aval.shape, (2, 3)) + self.assertEqual(out_aval.dtype, jnp.float32) + + def test_get_abstract_aval_must_take_in_refs(self): + with self.assertRaises(ValueError): + state.get_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32)) + + def test_swap_abstract_eval(self): + ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32) + val_aval = core.ShapedArray((2, 3), jnp.float32) + out_aval, effect = state.swap_p.abstract_eval(ref_aval, val_aval, 0) + self.assertSetEqual(effect, {state.StateEffect}) + self.assertTupleEqual(out_aval.shape, (2, 3)) + self.assertEqual(out_aval.dtype, jnp.float32) + + def test_swap_abstract_eval_must_take_in_refs(self): + with self.assertRaises(ValueError): + state.swap_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32), + core.ShapedArray((1, 2, 3), jnp.float32)) + + def test_swap_checks_for_correct_shapes(self): + with self.assertRaises(ValueError): + state.swap_p.abstract_eval( + state.ShapedArrayRef((1, 2, 3), jnp.float32), + core.ShapedArray((2, 3), jnp.float32)) + with self.assertRaises(ValueError): + state.swap_p.abstract_eval( + state.ShapedArrayRef((1, 2, 3), jnp.float32), + core.ShapedArray((1, 2, 3, 4), jnp.float32)) + state.swap_p.abstract_eval( + state.ShapedArrayRef((1, 2, 3), jnp.float32), + core.ShapedArray((2, 3), jnp.float32), 1) + + def test_addupdate_abstract_eval(self): + ref_aval = state.ShapedArrayRef((1, 2, 3), jnp.float32) + val_aval = core.ShapedArray((2, 3), jnp.float32) + out_avals, effect = state.addupdate_p.abstract_eval(ref_aval, val_aval, + 0) + self.assertSetEqual(effect, {state.StateEffect}) + self.assertListEqual(out_avals, []) + + def test_addupdate_abstract_eval_must_take_in_refs(self): + with self.assertRaises(ValueError): + state.addupdate_p.abstract_eval(core.ShapedArray((1, 2, 3), jnp.float32), + core.ShapedArray((1, 2, 3), jnp.float32)) + + def test_addupdate_checks_for_correct_shapes(self): + with self.assertRaises(ValueError): + state.addupdate_p.abstract_eval( + state.ShapedArrayRef((1, 2, 3), jnp.float32), + core.ShapedArray((2, 3), jnp.float32)) + with self.assertRaises(ValueError): + state.addupdate_p.abstract_eval( + state.ShapedArrayRef((1, 2, 3), jnp.float32), + core.ShapedArray((1, 2, 3, 4), jnp.float32)) + state.addupdate_p.abstract_eval( + state.ShapedArrayRef((1, 2, 3), jnp.float32), + core.ShapedArray((2, 3), jnp.float32), 1) + + def test_can_represent_get_and_swap_in_jaxprs(self): + + def body(x): + x[()] = jnp.int32(1) + x[()] = jnp.int32(2) + return (x[()],) + jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)]) + self.assertLen(consts, 0) + self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) + self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p) + self.assertEqual(jaxpr.eqns[1].primitive, state.swap_p) + self.assertEqual(jaxpr.eqns[2].primitive, state.get_p) + + def test_can_represent_addupdate_in_jaxprs(self): + + def body(x): + state.ref_addupdate(x, (), jnp.int32(1)) + return (x[()],) + jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)]) + self.assertLen(consts, 0) + self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) + self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p) + + def test_get_custom_pretty_printing_rule(self): + def body(x_ref): + x = x_ref[()] + return [x] + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)]) + self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False)) + + def test_set_custom_pretty_printing_rule(self): + def body(x_ref): + x_ref[()] = jnp.int32(2) + return [] + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)]) + self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) + + def test_swap_custom_pretty_printing_rule(self): + def body(x_ref): + x = state.ref_swap(x_ref, (), jnp.int32(2)) + return [x] + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)]) + self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) + + def test_addupdate_custom_pretty_printing_rule(self): + def body(x_ref): + state.ref_addupdate(x_ref, (), jnp.int32(2)) + return [] + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)]) + self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False)) + + def test_get_jvp(self): + + def f(r): + x = r[()] + return jnp.cos(x) + + def g(r, rdot): + return jax.jvp(f, (r,), (rdot,)) + + in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')), + state.ShapedArrayRef((), jnp.dtype('float32'))] + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) + self.assertEqual(jaxpr.eqns[0].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[1].primitive, state.get_p) + + def test_swap_jvp(self): + + def f(a): + x = a[()] + a[()] = jnp.sin(x) + return a[()] + + def g(r, rdot): + return jax.jvp(f, (r,), (rdot,)) + + in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')), + state.ShapedArrayRef((), jnp.dtype('float32'))] + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) + self.assertEqual(jaxpr.eqns[0].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[1].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p) + self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p) + self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p) + self.assertEqual(jaxpr.eqns[5].primitive, state.swap_p) + self.assertEqual(jaxpr.eqns[6].primitive, state.swap_p) + + def test_addupdate_jvp(self): + + def f(a): + state.ref_addupdate(a, (), 1.) + return a[()] + + def g(r, rdot): + return jax.jvp(f, (r,), (rdot,)) + + in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')), + state.ShapedArrayRef((), jnp.dtype('float32'))] + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals) + self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p) + self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p) + self.assertEqual(jaxpr.eqns[2].primitive, state.get_p) + self.assertEqual(jaxpr.eqns[3].primitive, state.get_p) + +class StateDischargeTest(jtu.JaxTestCase): + + def test_discharge_get(self): + def f(a_ref): + a = state.ref_get(a_ref, ()) + return [a + 1] + in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))] + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + # Discharging should just turn this into a jaxpr that just adds 1. + discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 1) + self.assertLen(discharged_jaxpr.outvars, 2) + self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p) + # Should be able to evaluate this jaxpr + self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (), + jnp.float32(1.)), [2., 1.]) + + def test_discharge_get_with_slice(self): + def f(a_ref): + a = state.ref_get(a_ref, (0, 1)) + return [a + 1] + in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))] + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + # Discharging should just turn this into a jaxpr that just adds 1. + discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 1) + self.assertLen(discharged_jaxpr.outvars, 2) + self.assertIn(lax.dynamic_slice_p, + set(eqn.primitive for eqn in discharged_jaxpr.eqns)) + # Should be able to evaluate this jaxpr + inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) + outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval) + self.assertTrue((outval == inval[0, 1] + 1).all()) + self.assertTrue((refval == inval).all()) + + def test_discharge_set(self): + def f(a_ref, b): + state.ref_set(a_ref, (), b + 1) + return [] + in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')), + core.ShapedArray((), jnp.dtype('float32'))] + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + # Discharging should just turn this into a jaxpr that ignores the first + # value and returns second value plus 1. + discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 2) + self.assertLen(discharged_jaxpr.outvars, 1) + self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.), + jnp.float32(1.))[0], 2.) + self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.), + jnp.float32(1.))[0], 2.) + + def test_discharge_set_with_slice(self): + def f(a_ref): + state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32'))) + return [] + in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))] + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + # Discharging should just turn this into a jaxpr that just adds 1. + discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 1) + self.assertLen(discharged_jaxpr.outvars, 1) + self.assertIn(lax.dynamic_update_slice_p, + set(eqn.primitive for eqn in discharged_jaxpr.eqns)) + self.assertIn(lax.dynamic_slice_p, + set(eqn.primitive for eqn in discharged_jaxpr.eqns)) + # Should be able to evaluate this jaxpr + inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) + refval, = core.eval_jaxpr(discharged_jaxpr, (), inval) + self.assertTrue((refval == inval.at[0, 1].set(1.)).all()) + + def test_discharge_addupdate(self): + def f(a_ref, b): + state.ref_addupdate(a_ref, (), b + 1) + return [] + in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')), + core.ShapedArray((), jnp.dtype('float32'))] + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + # Discharging should just turn this into a jaxpr that adds the first value, + # second value, and 1. + discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 2) + self.assertLen(discharged_jaxpr.outvars, 1) + self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.), + jnp.float32(1.))[0], 2.) + self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.), + jnp.float32(1.))[0], 4.) + + def test_discharge_addupdate_with_slice(self): + def f(a_ref): + state.ref_addupdate(a_ref, (0, 1), + jnp.ones(2, dtype=jnp.dtype('float32'))) + return [] + in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))] + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 1) + self.assertLen(discharged_jaxpr.outvars, 1) + self.assertIn(lax.dynamic_update_slice_p, + set(eqn.primitive for eqn in discharged_jaxpr.eqns)) + self.assertIn(lax.add_p, + set(eqn.primitive for eqn in discharged_jaxpr.eqns)) + self.assertIn(lax.dynamic_slice_p, + set(eqn.primitive for eqn in discharged_jaxpr.eqns)) + inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) + refval, = core.eval_jaxpr(discharged_jaxpr, (), inval) + self.assertTrue((refval == inval.at[0, 1].add(1.)).all()) + + def test_discharge_jaxpr_with_multiple_outputs(self): + def f(a_ref): + a = state.ref_get(a_ref, ()) + b = a + 1 + return [a, b] + in_avals = [state.ShapedArrayRef((4,), jnp.dtype('float32'))] + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 1) + self.assertLen(discharged_jaxpr.outvars, 3) + inval = jnp.arange(4., dtype=jnp.float32) + a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval) + self.assertTrue((a == inval).all()) + self.assertTrue((b == inval + 1).all()) + self.assertTrue((refval == inval).all()) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader())