Split State effect into Read/Write/Accum effects and tie them to Ref avals

This commit is contained in:
Sharad Vikram 2022-09-06 13:26:41 -07:00
parent b3393e3b60
commit b6c3b9df19
6 changed files with 40 additions and 17 deletions

View File

@ -51,7 +51,9 @@ T = TypeVar('T')
class Ref(Generic[T]): pass
Array = Any
StateEffect = state.StateEffect
ReadEffect = state.ReadEffect
WriteEffect = state.WriteEffect
AccumEffect = state.AccumEffect
ShapedArrayRef = state.ShapedArrayRef
ref_set = state.ref_set
ref_get = state.ref_get

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state."""
from jax._src.state.types import ShapedArrayRef, StateEffect
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
AccumEffect)
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
ref_addupdate, get_p, swap_p,
addupdate_p)

View File

@ -27,7 +27,8 @@ from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
from jax._src.state.types import ShapedArrayRef, StateEffect
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
AccumEffect)
## General utilities
@ -155,7 +156,7 @@ def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx, indexed_dims):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
idx_shapes = tuple(i.shape for i in idx)
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
return (core.ShapedArray(shape, ref_aval.dtype), {StateEffect})
return (core.ShapedArray(shape, ref_aval.dtype), {ReadEffect(ref_aval)})
get_p.def_effectful_abstract_eval(_get_abstract_eval)
@ -182,7 +183,7 @@ def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
return (core.ShapedArray(expected_output_shape, ref_aval.dtype),
{StateEffect})
{WriteEffect(ref_aval)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
@ -209,7 +210,7 @@ def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
raise ValueError("Invalid dtype for `addupdate`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
return [], {StateEffect}
return [], {AccumEffect(ref_aval)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
## Pretty printing for `get` and `swap` in jaxprs

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state types."""
from __future__ import annotations
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@ -44,11 +45,29 @@ zip, unsafe_zip = safe_zip, zip
Array = Any
class _StateEffect:
def __repr__(self):
return "State"
__str__ = __repr__
StateEffect = _StateEffect()
class RefEffect:
def __init__(self, ref_aval: ShapedArrayRef):
self.ref_aval = ref_aval
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.ref_aval is other.ref_aval
def __hash__(self):
return hash((self.__class__, self.ref_aval))
class ReadEffect(RefEffect):
def __str__(self):
return f"Read<{self.ref_aval}>"
class WriteEffect(RefEffect):
def __str__(self):
return f"Write<{self.ref_aval}>"
class AccumEffect(RefEffect):
def __str__(self):
return f"Accum<{self.ref_aval}>"
# ## `Ref`s

View File

@ -317,9 +317,9 @@ jax_test(
name = "lax_control_flow_test",
srcs = ["lax_control_flow_test.py"],
shard_count = {
"cpu": 20,
"gpu": 20,
"tpu": 20,
"cpu": 30,
"gpu": 30,
"tpu": 30,
"iree": 10,
},
)

View File

@ -97,7 +97,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval])
self.assertSetEqual(jaxpr.effects, {state.StateEffect})
self.assertSetEqual(jaxpr.effects, {state.ReadEffect(ref_aval)})
self.assertLen(out_avals, 1)
out_aval, = out_avals
self.assertIsInstance(out_aval, core.ShapedArray)
@ -163,7 +163,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects, {state.StateEffect})
self.assertSetEqual(jaxpr.effects, {state.WriteEffect(ref_aval)})
self.assertLen(out_avals, 1)
out_aval, = out_avals
self.assertIsInstance(out_aval, core.ShapedArray)
@ -218,7 +218,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects, {state.StateEffect})
self.assertSetEqual(jaxpr.effects, {state.AccumEffect(ref_aval)})
self.assertLen(out_avals, 0)
def test_addupdate_abstract_eval_must_take_in_refs(self):