Merge pull request #11618 from mattjj:scan-partial-eval-custom-fix

PiperOrigin-RevId: 463499406
This commit is contained in:
jax authors 2022-07-26 21:46:02 -07:00
commit 016c6df65e
2 changed files with 25 additions and 4 deletions

View File

@ -857,7 +857,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
num_const_known = len(const_uk) - sum(const_uk)
num_carry_known = len(carry_uk) - sum(carry_uk)
num_xs_known = len( xs_uk) - sum( xs_uk)
jaxpr_known_hoist, jaxpr_known_loop, loop_dep, _ = \
jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \
pe.partial_eval_jaxpr_nounits(
jaxpr_known,
[False] * num_const_known + [True] * (num_carry_known + num_xs_known),
@ -868,7 +868,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
jaxpr_staged = pe.move_binders_to_front(
jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
del loop_dep, num_carry_known, num_xs_known
del loop_dep, num_carry_known, num_xs_known, const_uk
# Create residual variables.
intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
@ -882,9 +882,13 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# jaxpr_known_hoist and a scan of jaxpr_known_loop.
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
linear_known = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
# jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist
# (corresponding to consts_known_lp_avals) followed by known carry and xs.
linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
_, linear_known_ = split_list(linear_known_, [num_const_known])
linear_known = [False] * len(consts_known_lp_avals) + linear_known_
params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
num_consts=len(const_uk)-sum(const_uk),
num_consts=len(consts_known_lp_avals),
num_carry=len(carry_uk)-sum(carry_uk),
linear=tuple(linear_known))

View File

@ -2521,6 +2521,23 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return lax.cond(x < 0., lambda x: x, lambda x: x, x)
jax.vmap(jax.jacrev(lambda x: cond_id(cond_id(x))))(jnp.ones(1))
@parameterized.named_parameters(
{"testcase_name": "impl={}".format(scan_name), "scan": scan_impl}
for scan_impl, scan_name in SCAN_IMPLS)
def test_scan_hoisting_consts(self, scan):
A = jnp.arange(4.).reshape(2, 2)
B = jnp.arange(4.).reshape(2, 2) + 1.
def f(x):
def body(c, _):
c1, c2, c3 = c
return (jnp.dot(A, c1), jnp.dot(B, c2), jnp.dot(jnp.sin(B), c3)), None
init_carry = (x * jnp.ones(2), x * jnp.ones(2), x * jnp.ones(2))
(c1, c2, c3), _ = scan(body, init_carry, None, length=3)
return jnp.sum(c1) + jnp.sum(c2) + jnp.sum(c3)
jax.grad(f)(1.) # doesn't crash
class ForLoopTest(jtu.JaxTestCase):