mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[mutable-arrays] enable refs without cps, and not just at top level
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent
24517ca3e0
commit
46a516275f
@ -1929,6 +1929,14 @@ def mutable_array(init_val):
|
||||
return mutable_array_p.bind(init_val)
|
||||
mutable_array_p = Primitive('mutable_array')
|
||||
|
||||
class InternalMutableArray(effects.Effect):
  """Effect type reported by the abstract-eval rule of `mutable_array_p`.

  The class defines no attributes or methods of its own.
  """
|
||||
|
||||
@mutable_array_p.def_effectful_abstract_eval
def mutable_array_abstract_eval(init_aval):
  """Abstract-eval rule for `mutable_array_p`.

  Args:
    init_aval: abstract value of the array used to initialize the ref.

  Returns:
    A pair `(out_aval, effects)`: an `AbstractRef` wrapping `init_aval`, and a
    set containing `InternalMutableArray`.
  """
  # Imported locally rather than at module level -- presumably to avoid an
  # import cycle with the state package; TODO confirm.
  from jax._src.state.types import AbstractRef  # type: ignore[import]
  ref_aval = AbstractRef(init_aval)
  # NOTE(review): the set holds the `InternalMutableArray` class object itself,
  # not an instance. Confirm that consumers which test
  # `isinstance(e, InternalMutableArray)` actually match this entry.
  return ref_aval, {InternalMutableArray}
|
||||
|
||||
@mutable_array_p.def_impl
|
||||
def _mutable_array_impl(init_val):
|
||||
from jax._src.state.types import AbstractRef # type: ignore[import]
|
||||
@ -2922,6 +2930,8 @@ def _check_jaxpr(
|
||||
write(v, v.aval)
|
||||
|
||||
# Check each eqn.
|
||||
sentinel = object()
|
||||
in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
|
||||
for eqn_idx, eqn in enumerate(jaxpr.eqns):
|
||||
prim = eqn.primitive
|
||||
try:
|
||||
@ -2943,6 +2953,9 @@ def _check_jaxpr(
|
||||
|
||||
# Check the computed effect type matches the eqn's annotation, and is
|
||||
# included in the jaxpr's annotation.
|
||||
if prim is mutable_array_p:
|
||||
outvar, = eqn.outvars
|
||||
in_idx[outvar] = None # type: ignore
|
||||
if eqn.effects != eqn_effects:
|
||||
raise JaxprTypeError("Inferred effects do not match equation effects. "
|
||||
f"Equation effects: {eqn.effects}. "
|
||||
@ -2950,11 +2963,9 @@ def _check_jaxpr(
|
||||
for eff in eqn.effects:
|
||||
if isinstance(eff, effects.JaxprInputEffect):
|
||||
eqn_invar = eqn.invars[eff.input_index]
|
||||
all_vars = [*jaxpr.constvars, *jaxpr.invars]
|
||||
if eqn_invar not in all_vars:
|
||||
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
|
||||
raise JaxprTypeError(
|
||||
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
|
||||
jaxpr_index = all_vars.index(eqn_invar)
|
||||
jaxpr_effect = eff.replace(input_index=jaxpr_index)
|
||||
if jaxpr_effect not in jaxpr.effects:
|
||||
raise JaxprTypeError(
|
||||
|
@ -1726,9 +1726,13 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
|
||||
|
||||
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
|
||||
sentinel = object()
|
||||
jaxpr_effects = set()
|
||||
all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))}
|
||||
for eqn in eqns:
|
||||
if eqn.primitive is core.mutable_array_p:
|
||||
outvar, = eqn.outvars
|
||||
all_vars[outvar] = None # type: ignore
|
||||
for eff in eqn.effects:
|
||||
if isinstance(eff, effects.JaxprInputEffect):
|
||||
if eff.input_index >= len(eqn.invars):
|
||||
@ -1738,7 +1742,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
|
||||
"\n Jaxpr: "
|
||||
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
|
||||
invar = eqn.invars[eff.input_index]
|
||||
if (input_index := all_vars.get(invar)) is None:
|
||||
if (input_index := all_vars.get(invar, sentinel)) is sentinel:
|
||||
raise ValueError(
|
||||
f"`JaxprInputEffect` {eff} does not have "
|
||||
f"corresponding input: {invar}."
|
||||
@ -2735,13 +2739,6 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
|
||||
return prim.bind(*subfuns, *args, **bind_params)
|
||||
|
||||
|
||||
def _error_staging_mutable_array_p(trace, x):
|
||||
raise Exception(
|
||||
"mutable_array constructor can't be staged out, and in particular can't "
|
||||
"be used under a jax.jit or jax.lax.scan")
|
||||
custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p
|
||||
|
||||
|
||||
# TODO(mattjj): the following are deprecated; update callers to _nounits version
|
||||
# See https://github.com/google/jax/pull/9498
|
||||
@lu.transformation
|
||||
|
@ -1817,6 +1817,13 @@ def _move_mutable_consts(
|
||||
effects, None)
|
||||
return core.ClosedJaxpr(jaxpr, consts), in_mut
|
||||
|
||||
@weakref_lru_cache
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
  """Run `discharge_state` on a closed jaxpr and re-wrap the result.

  Args:
    jaxpr: the closed jaxpr whose inner jaxpr and consts are discharged.

  Returns:
    A new `core.ClosedJaxpr` built from the discharged jaxpr and consts, with
    the debug info of the input's inner jaxpr assigned to it.
  """
  from jax._src.state.discharge import discharge_state
  discharged, discharged_consts = discharge_state(jaxpr.jaxpr, jaxpr.consts)
  # Carry the original debug info over to the discharged jaxpr.
  discharged._debug_info = jaxpr.jaxpr.debug_info
  return core.ClosedJaxpr(discharged, discharged_consts)
|
||||
|
||||
|
||||
class SemanticallyEqualShardings:
|
||||
|
||||
@ -2074,6 +2081,8 @@ def lower_sharding_computation(
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
else:
|
||||
inout_aliases = mut = None
|
||||
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
|
||||
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
|
||||
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
|
@ -96,9 +96,6 @@ def register_discharge_rule(prim: core.Primitive):
|
||||
_discharge_rules[prim] = f
|
||||
return register
|
||||
|
||||
def _has_refs(eqn: core.JaxprEqn):
  """Return True if any input variable of `eqn` has an `AbstractRef` aval."""
  for var in eqn.invars:
    if isinstance(var.aval, AbstractRef):
      return True
  return False
|
||||
|
||||
def _eval_jaxpr_discharge_state(
|
||||
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
|
||||
*args: Any):
|
||||
@ -113,8 +110,12 @@ def _eval_jaxpr_discharge_state(
|
||||
if d and isinstance(v.aval, AbstractRef)}
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if _has_refs(eqn) and any(id(v.aval) in refs_to_discharge
|
||||
for v in eqn.invars):
|
||||
if eqn.primitive is core.mutable_array_p:
|
||||
[invar], [outvar] = eqn.invars, eqn.outvars
|
||||
init_val = env.read(invar)
|
||||
env.write(outvar, init_val)
|
||||
refs_to_discharge.add(id(outvar.aval))
|
||||
elif any(id(v.aval) in refs_to_discharge for v in eqn.invars):
|
||||
if eqn.primitive not in _discharge_rules:
|
||||
raise NotImplementedError("No state discharge rule implemented for "
|
||||
f"primitive: {eqn.primitive}")
|
||||
|
@ -1588,6 +1588,21 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(w, 10, check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([True, False])
def test_internal_mutarray_basic(self, jit):
  """A ref created inside a function accepts in-place updates, jitted or not."""
  def f():
    x_mut = core.mutable_array(jnp.zeros(3))
    x_mut[0] += 1
    x_mut[0] += 1
    x_mut[2] += 1
    return x_mut[...]

  # Exercise the same body either eagerly or through jax.jit.
  fn = jax.jit(f) if jit else f

  result = fn()
  expected = jnp.array([2., 0., 1.])
  self.assertAllClose(result, expected, check_dtypes=False)
|
||||
|
||||
|
||||
if CAN_USE_HYPOTHESIS:
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user