mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11618 from mattjj:scan-partial-eval-custom-fix
PiperOrigin-RevId: 463499406
This commit is contained in:
commit
016c6df65e
@ -857,7 +857,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
num_const_known = len(const_uk) - sum(const_uk)
|
||||
num_carry_known = len(carry_uk) - sum(carry_uk)
|
||||
num_xs_known = len( xs_uk) - sum( xs_uk)
|
||||
jaxpr_known_hoist, jaxpr_known_loop, loop_dep, _ = \
|
||||
jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \
|
||||
pe.partial_eval_jaxpr_nounits(
|
||||
jaxpr_known,
|
||||
[False] * num_const_known + [True] * (num_carry_known + num_xs_known),
|
||||
@ -868,7 +868,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
jaxpr_staged = pe.move_binders_to_front(
|
||||
jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
|
||||
num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
|
||||
del loop_dep, num_carry_known, num_xs_known
|
||||
del loop_dep, num_carry_known, num_xs_known, const_uk
|
||||
|
||||
# Create residual variables.
|
||||
intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
|
||||
@ -882,9 +882,13 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
# jaxpr_known_hoist and a scan of jaxpr_known_loop.
|
||||
ins_known, _ = partition_list(unks_in, eqn.invars)
|
||||
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
||||
linear_known = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
|
||||
# jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist
|
||||
# (corresponding to consts_known_lp_avals) followed by known carry and xs.
|
||||
linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
|
||||
_, linear_known_ = split_list(linear_known_, [num_const_known])
|
||||
linear_known = [False] * len(consts_known_lp_avals) + linear_known_
|
||||
params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
|
||||
num_consts=len(const_uk)-sum(const_uk),
|
||||
num_consts=len(consts_known_lp_avals),
|
||||
num_carry=len(carry_uk)-sum(carry_uk),
|
||||
linear=tuple(linear_known))
|
||||
|
||||
|
@ -2521,6 +2521,23 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return lax.cond(x < 0., lambda x: x, lambda x: x, x)
|
||||
jax.vmap(jax.jacrev(lambda x: cond_id(cond_id(x))))(jnp.ones(1))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "impl={}".format(scan_name), "scan": scan_impl}
|
||||
for scan_impl, scan_name in SCAN_IMPLS)
|
||||
def test_scan_hoisting_consts(self, scan):
|
||||
A = jnp.arange(4.).reshape(2, 2)
|
||||
B = jnp.arange(4.).reshape(2, 2) + 1.
|
||||
|
||||
def f(x):
|
||||
def body(c, _):
|
||||
c1, c2, c3 = c
|
||||
return (jnp.dot(A, c1), jnp.dot(B, c2), jnp.dot(jnp.sin(B), c3)), None
|
||||
init_carry = (x * jnp.ones(2), x * jnp.ones(2), x * jnp.ones(2))
|
||||
(c1, c2, c3), _ = scan(body, init_carry, None, length=3)
|
||||
return jnp.sum(c1) + jnp.sum(c2) + jnp.sum(c3)
|
||||
|
||||
jax.grad(f)(1.) # doesn't crash
|
||||
|
||||
|
||||
class ForLoopTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user