From 46a516275f4910845426e8b1dfa6640cfd0940ce Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 2 Apr 2024 10:30:35 -0700 Subject: [PATCH] [mutable-arrays] enable refs without cps, and not just at top level Co-authored-by: Dougal Maclaurin --- jax/_src/core.py | 17 ++++++++++++++--- jax/_src/interpreters/partial_eval.py | 13 +++++-------- jax/_src/interpreters/pxla.py | 9 +++++++++ jax/_src/state/discharge.py | 11 ++++++----- tests/state_test.py | 15 +++++++++++++++ 5 files changed, 49 insertions(+), 16 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 868806e8d..d4a64ad6d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1929,6 +1929,14 @@ def mutable_array(init_val): return mutable_array_p.bind(init_val) mutable_array_p = Primitive('mutable_array') +class InternalMutableArray(effects.Effect): + pass + +@mutable_array_p.def_effectful_abstract_eval +def mutable_array_abstract_eval(init_aval): + from jax._src.state.types import AbstractRef # type: ignore[import] + return AbstractRef(init_aval), {InternalMutableArray} + @mutable_array_p.def_impl def _mutable_array_impl(init_val): from jax._src.state.types import AbstractRef # type: ignore[import] @@ -2922,6 +2930,8 @@ def _check_jaxpr( write(v, v.aval) # Check each eqn. + sentinel = object() + in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))} for eqn_idx, eqn in enumerate(jaxpr.eqns): prim = eqn.primitive try: @@ -2943,6 +2953,9 @@ def _check_jaxpr( # Check the computed effect type matches the eqn's annotation, and is # included in the jaxpr's annotation. + if prim is mutable_array_p: + outvar, = eqn.outvars + in_idx[outvar] = None # type: ignore if eqn.effects != eqn_effects: raise JaxprTypeError("Inferred effects do not match equation effects. " f"Equation effects: {eqn.effects}. " @@ -2950,11 +2963,9 @@ def _check_jaxpr( for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): eqn_invar = eqn.invars[eff.input_index] - all_vars = [*jaxpr.constvars, *jaxpr.invars] - if eqn_invar not in all_vars: + if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel: raise JaxprTypeError( "Invalid `JaxprInputEffect`: must correspond to a jaxpr invar") - jaxpr_index = all_vars.index(eqn_invar) jaxpr_effect = eff.replace(input_index=jaxpr_index) if jaxpr_effect not in jaxpr.effects: raise JaxprTypeError( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 4824da917..b3ec08ca0 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1726,9 +1726,13 @@ class DynamicJaxprTracer(core.Tracer): api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval") def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: + sentinel = object() jaxpr_effects = set() all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))} for eqn in eqns: + if eqn.primitive is core.mutable_array_p: + outvar, = eqn.outvars + all_vars[outvar] = None # type: ignore for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): if eff.input_index >= len(eqn.invars): @@ -1738,7 +1742,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") invar = eqn.invars[eff.input_index] - if (input_index := all_vars.get(invar)) is None: + if (input_index := all_vars.get(invar, sentinel)) is sentinel: raise ValueError( f"`JaxprInputEffect` {eff} does not have " f"corresponding input: {invar}." @@ -2735,13 +2739,6 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): return prim.bind(*subfuns, *args, **bind_params) -def _error_staging_mutable_array_p(trace, x): - raise Exception( - "mutable_array constructor can't be staged out, and in particular can't " - "be used under a jax.jit or jax.lax.scan") -custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p - - # TODO(mattjj): the following are deprecated; update callers to _nounits version # See https://github.com/google/jax/pull/9498 @lu.transformation diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 5dc21a61d..73b855a06 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1817,6 +1817,13 @@ def _move_mutable_consts( effects, None) return core.ClosedJaxpr(jaxpr, consts), in_mut +@weakref_lru_cache +def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: + from jax._src.state.discharge import discharge_state + jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts) + jaxpr_._debug_info = jaxpr.jaxpr.debug_info + return core.ClosedJaxpr(jaxpr_, consts) + class SemanticallyEqualShardings: @@ -2074,6 +2081,8 @@ def lower_sharding_computation( global_out_avals = closed_jaxpr.out_avals else: inout_aliases = mut = None + if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects): + closed_jaxpr = _discharge_internal_refs(closed_jaxpr) jaxpr = closed_jaxpr.jaxpr assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 9cd9a9b86..12d879c8b 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -96,9 +96,6 @@ def register_discharge_rule(prim: core.Primitive): _discharge_rules[prim] = f return register -def _has_refs(eqn: core.JaxprEqn): - return any(isinstance(v.aval, AbstractRef) for v in eqn.invars) - def _eval_jaxpr_discharge_state( jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any], *args: Any): @@ -113,8 +110,12 @@ def _eval_jaxpr_discharge_state( if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: - if _has_refs(eqn) and any(id(v.aval) in refs_to_discharge - for v in eqn.invars): + if eqn.primitive is core.mutable_array_p: + [invar], [outvar] = eqn.invars, eqn.outvars + init_val = env.read(invar) + env.write(outvar, init_val) + refs_to_discharge.add(id(outvar.aval)) + elif any(id(v.aval) in refs_to_discharge for v in eqn.invars): if eqn.primitive not in _discharge_rules: raise NotImplementedError("No state discharge rule implemented for " f"primitive: {eqn.primitive}") diff --git a/tests/state_test.py b/tests/state_test.py index b396ee7df..5913e0ec9 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -1588,6 +1588,21 @@ class MutableArrayTest(jtu.JaxTestCase): check_dtypes=False) self.assertAllClose(w, 10, check_dtypes=False) + @parameterized.parameters([True, False]) + def test_internal_mutarray_basic(self, jit): + def f(): + x_mut = core.mutable_array(jnp.zeros(3)) + x_mut[0] += 1 + x_mut[0] += 1 + x_mut[2] += 1 + return x_mut[...] + + if jit: + f = jax.jit(f) + + out = f() + self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False) + if CAN_USE_HYPOTHESIS: