Fix mutable array effects not being tracked properly

PiperOrigin-RevId: 680801564
This commit is contained in:
Sharad Vikram 2024-09-30 18:54:36 -07:00 committed by jax authors
parent 31cb3fd36e
commit 80f963c003
3 changed files with 17 additions and 0 deletions

View File

@ -2906,6 +2906,7 @@ def _check_jaxpr(
# Check each eqn.
sentinel = object()
in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
mut_arrays = set()
for eqn_idx, eqn in enumerate(jaxpr.eqns):
prim = eqn.primitive
try:
@ -2930,6 +2931,7 @@ def _check_jaxpr(
if prim is mutable_array_p:
outvar, = eqn.outvars
in_idx[outvar] = None # type: ignore
mut_arrays.add(outvar)
if eqn.effects != eqn_effects:
raise JaxprTypeError("Inferred effects do not match equation effects. "
f"Equation effects: {eqn.effects}. "
@ -2937,6 +2939,8 @@ def _check_jaxpr(
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
eqn_invar = eqn.invars[eff.input_index]
if eqn_invar in mut_arrays:
continue
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")

View File

@ -1668,10 +1668,12 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
sentinel = object()
jaxpr_effects = set()
all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))}
mut_arrays = set()
for eqn in eqns:
if eqn.primitive is core.mutable_array_p:
outvar, = eqn.outvars
all_vars[outvar] = None # type: ignore
mut_arrays.add(outvar)
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
if eff.input_index >= len(eqn.invars):
@ -1681,6 +1683,8 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
"\n Jaxpr: "
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
invar = eqn.invars[eff.input_index]
if invar in mut_arrays:
continue
if (input_index := all_vars.get(invar, sentinel)) is sentinel:
raise ValueError(
f"`JaxprInputEffect` {eff} does not have "

View File

@ -223,5 +223,14 @@ class MutableArrayTest(jtu.JaxTestCase):
_, xs = doit()
self.assertAllClose(xs, (np.arange(5) * 2), check_dtypes=False)
def test_double_jit_mutable_array(self):
@jax.jit
@jax.jit
def f():
x_ref = core.mutable_array(jnp.zeros(8))
return x_ref[...]
x = f()
self.assertArraysEqual(x, jnp.zeros(8))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())