[run_state] add scan nested test, tweak rule name to mention 'state'

This commit is contained in:
Matthew Johnson 2023-10-04 11:48:42 -07:00
parent 923498fb45
commit 29af93b4cb
2 changed files with 34 additions and 3 deletions

View File

@ -1049,8 +1049,8 @@ def _scan_pp_rule(eqn, context, settings):
del printed_params['reverse']
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
def _scan_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
num_carry, linear, unroll, reverse, length):
def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
num_carry, linear, unroll, reverse, length):
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
if consts: raise NotImplementedError
consts, carry, xs = split_list(args, [num_consts, num_carry])
@ -1129,7 +1129,7 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
state_discharge.register_discharge_rule(scan_p)(_scan_discharge_rule)
state_discharge.register_discharge_rule(scan_p)(_scan_state_discharge_rule)
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule

View File

@ -1205,6 +1205,37 @@ class StateControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(jax.jit(f)(0, 1, 5, zs), (35, 11))
self.assertAllClose(jax.jit(f)(1, 1, 2, zs), (21, 11))
def test_scan_with_state_in_body_nested(self):
@run_state
def g(refs):
a_ref, x_ref, w_ref, y_ref, zs_ref = refs
def f(x, w, y, zs):
@run_state
def loop(refs):
x_ref, w_ref = refs
def body(y, z):
x_ref[...] += y
w_ref[...] += z
a_ref[...] += z
return y + 1, ()
lax.scan(body, y, zs)
a_ref[...] += 1
out = loop((x, w))
a_ref[...] += 1
return out
x, w = f(x_ref[...], w_ref[...], y_ref[...], zs_ref[...])
x_ref[...] = x
w_ref[...] = w
zs = jnp.arange(5)
jaxpr = jax.make_jaxpr(g)((1, 0, 1, 5, zs)).jaxpr
self.assertEmpty(jaxpr.effects)
self.assertAllClose(jax.jit(g)((1, 0, 1, 5, zs))[:3], (13, 35, 11))
self.assertAllClose(jax.jit(g)((1, 1, 1, 2, zs))[:3], (13, 21, 11))
class GeneralRefTest(jtu.JaxTestCase):
def test_unshaped_ref(self):