mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
This commit is contained in:
parent
31cb3fd36e
commit
80f963c003
@ -2906,6 +2906,7 @@ def _check_jaxpr(
|
||||
# Check each eqn.
|
||||
sentinel = object()
|
||||
in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
|
||||
mut_arrays = set()
|
||||
for eqn_idx, eqn in enumerate(jaxpr.eqns):
|
||||
prim = eqn.primitive
|
||||
try:
|
||||
@ -2930,6 +2931,7 @@ def _check_jaxpr(
|
||||
if prim is mutable_array_p:
|
||||
outvar, = eqn.outvars
|
||||
in_idx[outvar] = None # type: ignore
|
||||
mut_arrays.add(outvar)
|
||||
if eqn.effects != eqn_effects:
|
||||
raise JaxprTypeError("Inferred effects do not match equation effects. "
|
||||
f"Equation effects: {eqn.effects}. "
|
||||
@ -2937,6 +2939,8 @@ def _check_jaxpr(
|
||||
for eff in eqn.effects:
|
||||
if isinstance(eff, effects.JaxprInputEffect):
|
||||
eqn_invar = eqn.invars[eff.input_index]
|
||||
if eqn_invar in mut_arrays:
|
||||
continue
|
||||
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
|
||||
raise JaxprTypeError(
|
||||
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
|
||||
|
@ -1668,10 +1668,12 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
|
||||
sentinel = object()
|
||||
jaxpr_effects = set()
|
||||
all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))}
|
||||
mut_arrays = set()
|
||||
for eqn in eqns:
|
||||
if eqn.primitive is core.mutable_array_p:
|
||||
outvar, = eqn.outvars
|
||||
all_vars[outvar] = None # type: ignore
|
||||
mut_arrays.add(outvar)
|
||||
for eff in eqn.effects:
|
||||
if isinstance(eff, effects.JaxprInputEffect):
|
||||
if eff.input_index >= len(eqn.invars):
|
||||
@ -1681,6 +1683,8 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
|
||||
"\n Jaxpr: "
|
||||
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
|
||||
invar = eqn.invars[eff.input_index]
|
||||
if invar in mut_arrays:
|
||||
continue
|
||||
if (input_index := all_vars.get(invar, sentinel)) is sentinel:
|
||||
raise ValueError(
|
||||
f"`JaxprInputEffect` {eff} does not have "
|
||||
|
@ -223,5 +223,14 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
_, xs = doit()
|
||||
self.assertAllClose(xs, (np.arange(5) * 2), check_dtypes=False)
|
||||
|
||||
def test_double_jit_mutable_array(self):
|
||||
@jax.jit
|
||||
@jax.jit
|
||||
def f():
|
||||
x_ref = core.mutable_array(jnp.zeros(8))
|
||||
return x_ref[...]
|
||||
x = f()
|
||||
self.assertArraysEqual(x, jnp.zeros(8))
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user