mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Split State effect into Read/Write/Accum effects and tie them to Ref avals
This commit is contained in:
parent
b3393e3b60
commit
b6c3b9df19
@ -51,7 +51,9 @@ T = TypeVar('T')
|
||||
class Ref(Generic[T]): pass
|
||||
Array = Any
|
||||
|
||||
StateEffect = state.StateEffect
|
||||
ReadEffect = state.ReadEffect
|
||||
WriteEffect = state.WriteEffect
|
||||
AccumEffect = state.AccumEffect
|
||||
ShapedArrayRef = state.ShapedArrayRef
|
||||
ref_set = state.ref_set
|
||||
ref_get = state.ref_get
|
||||
|
@ -12,7 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for state."""
|
||||
from jax._src.state.types import ShapedArrayRef, StateEffect
|
||||
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
|
||||
AccumEffect)
|
||||
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
|
||||
ref_addupdate, get_p, swap_p,
|
||||
addupdate_p)
|
||||
|
@ -27,7 +27,8 @@ from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src.state.types import ShapedArrayRef, StateEffect
|
||||
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
|
||||
AccumEffect)
|
||||
|
||||
## General utilities
|
||||
|
||||
@ -155,7 +156,7 @@ def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx, indexed_dims):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
|
||||
return (core.ShapedArray(shape, ref_aval.dtype), {StateEffect})
|
||||
return (core.ShapedArray(shape, ref_aval.dtype), {ReadEffect(ref_aval)})
|
||||
get_p.def_effectful_abstract_eval(_get_abstract_eval)
|
||||
|
||||
|
||||
@ -182,7 +183,7 @@ def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
return (core.ShapedArray(expected_output_shape, ref_aval.dtype),
|
||||
{StateEffect})
|
||||
{WriteEffect(ref_aval)})
|
||||
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||||
|
||||
|
||||
@ -209,7 +210,7 @@ def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
|
||||
raise ValueError("Invalid dtype for `addupdate`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
return [], {StateEffect}
|
||||
return [], {AccumEffect(ref_aval)}
|
||||
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
||||
|
||||
## Pretty printing for `get` and `swap` in jaxprs
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for state types."""
|
||||
from __future__ import annotations
|
||||
from functools import partial
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
@ -44,11 +45,29 @@ zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
Array = Any
|
||||
|
||||
class _StateEffect:
|
||||
def __repr__(self):
|
||||
return "State"
|
||||
__str__ = __repr__
|
||||
StateEffect = _StateEffect()
|
||||
class RefEffect:
|
||||
def __init__(self, ref_aval: ShapedArrayRef):
|
||||
self.ref_aval = ref_aval
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return self.ref_aval is other.ref_aval
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.__class__, self.ref_aval))
|
||||
|
||||
class ReadEffect(RefEffect):
|
||||
def __str__(self):
|
||||
return f"Read<{self.ref_aval}>"
|
||||
|
||||
class WriteEffect(RefEffect):
|
||||
def __str__(self):
|
||||
return f"Write<{self.ref_aval}>"
|
||||
|
||||
class AccumEffect(RefEffect):
|
||||
def __str__(self):
|
||||
return f"Accum<{self.ref_aval}>"
|
||||
|
||||
# ## `Ref`s
|
||||
|
||||
|
@ -317,9 +317,9 @@ jax_test(
|
||||
name = "lax_control_flow_test",
|
||||
srcs = ["lax_control_flow_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 20,
|
||||
"gpu": 20,
|
||||
"tpu": 20,
|
||||
"cpu": 30,
|
||||
"gpu": 30,
|
||||
"tpu": 30,
|
||||
"iree": 10,
|
||||
},
|
||||
)
|
||||
|
@ -97,7 +97,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
else:
|
||||
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval])
|
||||
self.assertSetEqual(jaxpr.effects, {state.StateEffect})
|
||||
self.assertSetEqual(jaxpr.effects, {state.ReadEffect(ref_aval)})
|
||||
self.assertLen(out_avals, 1)
|
||||
out_aval, = out_avals
|
||||
self.assertIsInstance(out_aval, core.ShapedArray)
|
||||
@ -163,7 +163,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
else:
|
||||
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, val_aval])
|
||||
self.assertSetEqual(jaxpr.effects, {state.StateEffect})
|
||||
self.assertSetEqual(jaxpr.effects, {state.WriteEffect(ref_aval)})
|
||||
self.assertLen(out_avals, 1)
|
||||
out_aval, = out_avals
|
||||
self.assertIsInstance(out_aval, core.ShapedArray)
|
||||
@ -218,7 +218,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
else:
|
||||
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [ref_aval, val_aval])
|
||||
self.assertSetEqual(jaxpr.effects, {state.StateEffect})
|
||||
self.assertSetEqual(jaxpr.effects, {state.AccumEffect(ref_aval)})
|
||||
self.assertLen(out_avals, 0)
|
||||
|
||||
def test_addupdate_abstract_eval_must_take_in_refs(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user