diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 823d7c7e2..f57a857ec 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1049,8 +1049,8 @@ def _scan_pp_rule(eqn, context, settings): del printed_params['reverse'] return core._pp_eqn(eqn.replace(params=printed_params), context, settings) -def _scan_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, - num_carry, linear, unroll, reverse, length): +def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, + num_carry, linear, unroll, reverse, length): jaxpr, consts = jaxpr.jaxpr, jaxpr.consts if consts: raise NotImplementedError consts, carry, xs = split_list(args, [num_consts, num_carry]) @@ -1129,7 +1129,7 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom pe.padding_rules[scan_p] = _scan_padding_rule pe.dce_rules[scan_p] = _scan_dce_rule -state_discharge.register_discharge_rule(scan_p)(_scan_discharge_rule) +state_discharge.register_discharge_rule(scan_p)(_scan_state_discharge_rule) # TODO(mattjj,frostig): un-comment this pp rule # core.pp_eqn_rules[scan_p] = _scan_pp_rule diff --git a/tests/state_test.py b/tests/state_test.py index e964b1d5f..b52bea429 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -1205,6 +1205,37 @@ class StateControlFlowTest(jtu.JaxTestCase): self.assertAllClose(jax.jit(f)(0, 1, 5, zs), (35, 11)) self.assertAllClose(jax.jit(f)(1, 1, 2, zs), (21, 11)) + def test_scan_with_state_in_body_nested(self): + @run_state + def g(refs): + a_ref, x_ref, w_ref, y_ref, zs_ref = refs + + def f(x, w, y, zs): + @run_state + def loop(refs): + x_ref, w_ref = refs + def body(y, z): + x_ref[...] += y + w_ref[...] += z + a_ref[...] += z + return y + 1, () + lax.scan(body, y, zs) + a_ref[...] += 1 + out = loop((x, w)) + a_ref[...] += 1 + return out + + x, w = f(x_ref[...], w_ref[...], y_ref[...], zs_ref[...]) + x_ref[...] = x + w_ref[...] = w + + zs = jnp.arange(5) + jaxpr = jax.make_jaxpr(g)((1, 0, 1, 5, zs)).jaxpr + self.assertEmpty(jaxpr.effects) + self.assertAllClose(jax.jit(g)((1, 0, 1, 5, zs))[:3], (13, 35, 11)) + self.assertAllClose(jax.jit(g)((1, 1, 1, 2, zs))[:3], (13, 21, 11)) + + class GeneralRefTest(jtu.JaxTestCase): def test_unshaped_ref(self):