[mutable-arrays] enable refs without cps, and not just at top level

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
Matthew Johnson 2024-04-02 10:30:35 -07:00
parent 24517ca3e0
commit 46a516275f
5 changed files with 49 additions and 16 deletions

View File

@ -1929,6 +1929,14 @@ def mutable_array(init_val):
return mutable_array_p.bind(init_val)
mutable_array_p = Primitive('mutable_array')
class InternalMutableArray(effects.Effect):
  """Effect type reported by the abstract-eval rule of `mutable_array_p`.

  It carries no data of its own; it only tags computations that create a
  mutable array internally.
  """
# Effect sets must hold *instances* of an effect type: consumers detect this
# effect with `isinstance(e, InternalMutableArray)`, which is always False for
# the class object itself. A single shared instance is used so that effect sets
# produced by separate abstract-eval calls contain the very same object and
# therefore compare equal.
internal_mutable_array_effect = InternalMutableArray()

@mutable_array_p.def_effectful_abstract_eval
def mutable_array_abstract_eval(init_aval):
  """Abstract-evaluation rule for `mutable_array_p`.

  Args:
    init_aval: abstract value of the array used to initialize the ref.

  Returns:
    A pair `(out_aval, effects)`: `out_aval` is `AbstractRef(init_aval)` and
    `effects` is a set holding the shared `InternalMutableArray` instance.
  """
  # Function-scope import, kept as in the original rule.
  # NOTE(review): presumably avoids an import cycle with jax._src.state -- confirm.
  from jax._src.state.types import AbstractRef  # type: ignore[import]
  return AbstractRef(init_aval), {internal_mutable_array_effect}
@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
from jax._src.state.types import AbstractRef # type: ignore[import]
@ -2922,6 +2930,8 @@ def _check_jaxpr(
write(v, v.aval)
# Check each eqn.
sentinel = object()
in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
for eqn_idx, eqn in enumerate(jaxpr.eqns):
prim = eqn.primitive
try:
@ -2943,6 +2953,9 @@ def _check_jaxpr(
# Check the computed effect type matches the eqn's annotation, and is
# included in the jaxpr's annotation.
if prim is mutable_array_p:
outvar, = eqn.outvars
in_idx[outvar] = None # type: ignore
if eqn.effects != eqn_effects:
raise JaxprTypeError("Inferred effects do not match equation effects. "
f"Equation effects: {eqn.effects}. "
@ -2950,11 +2963,9 @@ def _check_jaxpr(
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
eqn_invar = eqn.invars[eff.input_index]
all_vars = [*jaxpr.constvars, *jaxpr.invars]
if eqn_invar not in all_vars:
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
jaxpr_index = all_vars.index(eqn_invar)
jaxpr_effect = eff.replace(input_index=jaxpr_index)
if jaxpr_effect not in jaxpr.effects:
raise JaxprTypeError(

View File

@ -1726,9 +1726,13 @@ class DynamicJaxprTracer(core.Tracer):
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
sentinel = object()
jaxpr_effects = set()
all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))}
for eqn in eqns:
if eqn.primitive is core.mutable_array_p:
outvar, = eqn.outvars
all_vars[outvar] = None # type: ignore
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
if eff.input_index >= len(eqn.invars):
@ -1738,7 +1742,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
"\n Jaxpr: "
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
invar = eqn.invars[eff.input_index]
if (input_index := all_vars.get(invar)) is None:
if (input_index := all_vars.get(invar, sentinel)) is sentinel:
raise ValueError(
f"`JaxprInputEffect` {eff} does not have "
f"corresponding input: {invar}."
@ -2735,13 +2739,6 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
return prim.bind(*subfuns, *args, **bind_params)
def _error_staging_mutable_array_p(trace, x):
  """Staging rule for `mutable_array` that always fails.

  Neither `trace` nor `x` is inspected.

  Raises:
    Exception: unconditionally, explaining that the constructor cannot be
      staged out.
  """
  message = (
      "mutable_array constructor can't be staged out, and in particular can't "
      "be used under a jax.jit or jax.lax.scan")
  raise Exception(message)
# Register the always-raising rule above under `mutable_array_p` in the
# `custom_staging_rules` table.
custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p
# TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498
@lu.transformation

View File

@ -1817,6 +1817,13 @@ def _move_mutable_consts(
effects, None)
return core.ClosedJaxpr(jaxpr, consts), in_mut
@weakref_lru_cache
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
  """Run `discharge_state` over a closed jaxpr and re-close the result.

  Args:
    jaxpr: the closed jaxpr whose inner jaxpr and consts are passed to
      `discharge_state`.

  Returns:
    A new `core.ClosedJaxpr` built from the discharged jaxpr and the consts
    returned by `discharge_state`, carrying the debug info of the input jaxpr.
  """
  from jax._src.state.discharge import discharge_state
  inner = jaxpr.jaxpr
  discharged, new_consts = discharge_state(inner, jaxpr.consts)
  # Copy the debug info from the original jaxpr onto the discharged one.
  discharged._debug_info = inner.debug_info
  return core.ClosedJaxpr(discharged, new_consts)
class SemanticallyEqualShardings:
@ -2074,6 +2081,8 @@ def lower_sharding_computation(
global_out_avals = closed_jaxpr.out_avals
else:
inout_aliases = mut = None
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
jaxpr = closed_jaxpr.jaxpr
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (

View File

@ -96,9 +96,6 @@ def register_discharge_rule(prim: core.Primitive):
_discharge_rules[prim] = f
return register
def _has_refs(eqn: core.JaxprEqn):
  """Return True if at least one input variable of `eqn` has an `AbstractRef` aval."""
  for var in eqn.invars:
    if isinstance(var.aval, AbstractRef):
      return True
  return False
def _eval_jaxpr_discharge_state(
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
*args: Any):
@ -113,8 +110,12 @@ def _eval_jaxpr_discharge_state(
if d and isinstance(v.aval, AbstractRef)}
for eqn in jaxpr.eqns:
if _has_refs(eqn) and any(id(v.aval) in refs_to_discharge
for v in eqn.invars):
if eqn.primitive is core.mutable_array_p:
[invar], [outvar] = eqn.invars, eqn.outvars
init_val = env.read(invar)
env.write(outvar, init_val)
refs_to_discharge.add(id(outvar.aval))
elif any(id(v.aval) in refs_to_discharge for v in eqn.invars):
if eqn.primitive not in _discharge_rules:
raise NotImplementedError("No state discharge rule implemented for "
f"primitive: {eqn.primitive}")

View File

@ -1588,6 +1588,21 @@ class MutableArrayTest(jtu.JaxTestCase):
check_dtypes=False)
self.assertAllClose(w, 10, check_dtypes=False)
@parameterized.parameters([True, False])
def test_internal_mutarray_basic(self, jit):
  """In-place updates on a mutable array created inside the function.

  Runs the same computation directly and under `jax.jit` (selected by `jit`)
  and checks the values read back from the array.
  """
  def body():
    ref = core.mutable_array(jnp.zeros(3))
    # Two increments at index 0, then one at index 2.
    for idx in (0, 0, 2):
      ref[idx] += 1
    return ref[...]

  fn = jax.jit(body) if jit else body
  result = fn()
  self.assertAllClose(result, jnp.array([2., 0., 1.]), check_dtypes=False)
if CAN_USE_HYPOTHESIS: