mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[run_state] add scan nested test, tweak rule name to mention 'state'
This commit is contained in:
parent
923498fb45
commit
29af93b4cb
@ -1049,8 +1049,8 @@ def _scan_pp_rule(eqn, context, settings):
|
||||
del printed_params['reverse']
|
||||
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
|
||||
|
||||
def _scan_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
|
||||
num_carry, linear, unroll, reverse, length):
|
||||
def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
|
||||
num_carry, linear, unroll, reverse, length):
|
||||
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
|
||||
if consts: raise NotImplementedError
|
||||
consts, carry, xs = split_list(args, [num_consts, num_carry])
|
||||
@ -1129,7 +1129,7 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
|
||||
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
|
||||
pe.padding_rules[scan_p] = _scan_padding_rule
|
||||
pe.dce_rules[scan_p] = _scan_dce_rule
|
||||
state_discharge.register_discharge_rule(scan_p)(_scan_discharge_rule)
|
||||
state_discharge.register_discharge_rule(scan_p)(_scan_state_discharge_rule)
|
||||
# TODO(mattjj,frostig): un-comment this pp rule
|
||||
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
|
||||
|
||||
|
@ -1205,6 +1205,37 @@ class StateControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jax.jit(f)(0, 1, 5, zs), (35, 11))
|
||||
self.assertAllClose(jax.jit(f)(1, 1, 2, zs), (21, 11))
|
||||
|
||||
def test_scan_with_state_in_body_nested(self):
|
||||
@run_state
|
||||
def g(refs):
|
||||
a_ref, x_ref, w_ref, y_ref, zs_ref = refs
|
||||
|
||||
def f(x, w, y, zs):
|
||||
@run_state
|
||||
def loop(refs):
|
||||
x_ref, w_ref = refs
|
||||
def body(y, z):
|
||||
x_ref[...] += y
|
||||
w_ref[...] += z
|
||||
a_ref[...] += z
|
||||
return y + 1, ()
|
||||
lax.scan(body, y, zs)
|
||||
a_ref[...] += 1
|
||||
out = loop((x, w))
|
||||
a_ref[...] += 1
|
||||
return out
|
||||
|
||||
x, w = f(x_ref[...], w_ref[...], y_ref[...], zs_ref[...])
|
||||
x_ref[...] = x
|
||||
w_ref[...] = w
|
||||
|
||||
zs = jnp.arange(5)
|
||||
jaxpr = jax.make_jaxpr(g)((1, 0, 1, 5, zs)).jaxpr
|
||||
self.assertEmpty(jaxpr.effects)
|
||||
self.assertAllClose(jax.jit(g)((1, 0, 1, 5, zs))[:3], (13, 35, 11))
|
||||
self.assertAllClose(jax.jit(g)((1, 1, 1, 2, zs))[:3], (13, 21, 11))
|
||||
|
||||
|
||||
class GeneralRefTest(jtu.JaxTestCase):
|
||||
|
||||
def test_unshaped_ref(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user