mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11656 from mattjj:while-loop-new-remat
PiperOrigin-RevId: 463973829
This commit is contained in:
commit
a636bd3468
@ -446,7 +446,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
unks_out: List[bool] = [False] * len(eqn.outvars)
|
||||
for jaxpr in branches:
|
||||
_, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
|
||||
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
|
||||
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
|
||||
ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
|
||||
unks_out = map(operator.or_, unks_out, unks_out_)
|
||||
|
||||
@ -458,7 +458,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
for jaxpr in branches:
|
||||
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
|
||||
pe.partial_eval_jaxpr_custom(
|
||||
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
|
||||
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
|
||||
ensure_out_unknowns=unks_out, ensure_out_inst=True,
|
||||
saveable=saveable)
|
||||
branches_known_.append( core.ClosedJaxpr(jaxpr_known, jaxpr.consts))
|
||||
@ -481,7 +481,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
# passing in_inst argument to partial_eval_jaxpr_custom above).
|
||||
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
||||
if type(x) is core.Var and not inst]
|
||||
inst_in = [True] * len(inst_in)
|
||||
del inst_in
|
||||
|
||||
# Create residual variables.
|
||||
newvar = core.gensym()
|
||||
|
@ -828,14 +828,14 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
unks_in = const_uk + carry_uk + xs_uk
|
||||
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
|
||||
pe.partial_eval_jaxpr_custom(
|
||||
jaxpr.jaxpr, in_unknowns=unks_in, in_inst=[True] * len(unks_in),
|
||||
jaxpr.jaxpr, in_unknowns=unks_in, in_inst=True,
|
||||
ensure_out_unknowns=carry_uk + [False] * num_ys,
|
||||
ensure_out_inst=True, saveable=saveable)
|
||||
carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
|
||||
if carry_uk_out == carry_uk:
|
||||
break
|
||||
else:
|
||||
carry_uk = _map(operator.or_, carry_uk , carry_uk_out )
|
||||
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
|
||||
else:
|
||||
assert False, "Fixpoint not reached"
|
||||
jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
|
||||
@ -1309,6 +1309,70 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk]
|
||||
return util.merge_lists(carry_uk, out_known, out_tracers)
|
||||
|
||||
# TODO(mattjj): de-duplicate code with _while_partial_eval
|
||||
def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Partial-eval rule for ``while_p`` under ``partial_eval_jaxpr_custom``.

  Splits a while-loop equation into a "known" equation, built from the known
  parts of the cond and body jaxprs, and a "staged" equation, which is the
  original equation unchanged. No residuals flow from the known loop to the
  staged one (asserted below).

  Args:
    saveable: residual-saving policy. Ignored: ``nothing_saveable`` is always
      used for the inner partial evaluations.
    unks_in: one flag per entry of ``eqn.invars`` (unknown-ness), laid out as
      cond consts, then body consts, then the initial carry.
    inst_in: one flag per entry of ``eqn.invars``; inputs whose flag is falsy
      are reported back in ``new_inst``.
    eqn: the while-loop equation being partially evaluated.

  Returns:
    ``(eqn_known, eqn_staged, unks_out, inst_out, new_inst)`` where
    ``unks_out`` is the fixpoint carry unknown-ness, ``inst_out`` is all-True,
    and ``new_inst`` lists the ``core.Var`` inputs that were not yet
    instantiated.
  """
  del saveable  # No residuals can be saved anyway (w/o dynamic shapes)!

  params = eqn.params
  cond_jaxpr, cond_nconsts = params['cond_jaxpr'], params['cond_nconsts']
  body_jaxpr, body_nconsts = params['body_jaxpr'], params['body_nconsts']

  cond_consts_uk, body_consts_uk, carry_init_uk = split_list(
      unks_in, [cond_nconsts, body_nconsts])

  # Iterate to a fixpoint on which carry components are unknown. Nothing to do
  # for 'inst_in': every input is made available and DCE can later prune the
  # unused ones.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    jaxpr_known_, _, carry_uk_out, _, num_res = pe.partial_eval_jaxpr_custom(
        body_jaxpr.jaxpr,
        in_unknowns=body_consts_uk + carry_uk,
        in_inst=True,
        ensure_out_unknowns=carry_uk,
        ensure_out_inst=True,
        saveable=ad_checkpoint.nothing_saveable)
    if carry_uk_out == carry_uk:
      break
    carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  assert not num_res
  body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts)
  del jaxpr_known_, carry_uk_out, num_res

  # Known part of cond_fun (essentially prunes inputs on the known side).
  cond_jaxpr_known_, _, [cond_uk], _, _ = pe.partial_eval_jaxpr_custom(
      cond_jaxpr.jaxpr,
      cond_consts_uk + carry_uk,
      in_inst=True,
      ensure_out_unknowns=False,
      ensure_out_inst=True,
      saveable=ad_checkpoint.nothing_saveable)
  assert not cond_uk  # only possible with old-style remat
  cond_jaxpr_known = core.ClosedJaxpr(cond_jaxpr_known_, cond_jaxpr.consts)
  del cond_uk

  # Assemble the known equation from the known inputs and known carry outputs.
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(carry_uk, eqn.outvars)
  params_known = {
      'cond_jaxpr': cond_jaxpr_known,
      'body_jaxpr': body_jaxpr_known,
      'cond_nconsts': len(cond_consts_uk) - sum(cond_consts_uk),
      'body_nconsts': len(body_consts_uk) - sum(body_consts_uk),
  }
  effects_known = core.join_effects(
      cond_jaxpr_known.effects, body_jaxpr_known.effects)
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, out_binders_known, while_p, params_known, effects_known,
      eqn.source_info)

  # The staged equation is the input equation itself, so it consumes every
  # input: report all not-yet-instantiated variables as newly instantiated.
  new_inst = [
      x for x, inst in zip(eqn.invars, inst_in)
      if type(x) is core.Var and not inst
  ]

  return eqn_known, eqn, carry_uk, [True] * len(carry_uk), new_inst
|
||||
|
||||
def _while_transpose_error(*_, **kwargs):
|
||||
raise ValueError("Reverse-mode differentiation does not work for "
|
||||
"lax.while_loop or lax.fori_loop. "
|
||||
@ -1323,8 +1387,7 @@ pe.custom_partial_eval_rules[while_p] = _while_partial_eval
|
||||
xla.register_initial_style_primitive(while_p)
|
||||
ad.primitive_transposes[while_p] = _while_transpose_error
|
||||
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
|
||||
pe.partial_eval_jaxpr_custom_rules[while_p] = \
|
||||
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'while_loop')
|
||||
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
|
||||
|
||||
|
||||
def _pred_bcast_select_mhlo(
|
||||
|
@ -1241,11 +1241,13 @@ call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
||||
def partial_eval_jaxpr_custom(
|
||||
jaxpr: Jaxpr,
|
||||
in_unknowns: Sequence[bool],
|
||||
in_inst: Sequence[bool],
|
||||
in_inst: Union[bool, Sequence[bool]],
|
||||
ensure_out_unknowns: Union[bool, Sequence[bool]],
|
||||
ensure_out_inst: Union[bool, Sequence[bool]],
|
||||
saveable: Callable[..., bool],
|
||||
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
|
||||
if type(in_inst) is bool:
|
||||
in_inst = (in_inst,) * len(jaxpr.invars)
|
||||
if type(ensure_out_unknowns) is bool:
|
||||
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
|
||||
if type(ensure_out_inst) is bool:
|
||||
|
@ -5207,6 +5207,50 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 2)
|
||||
|
||||
@parameterized.named_parameters(
    {"testcase_name": f"{suffix}", "remat": remat}
    for suffix, remat in [
        ('', api.remat),
        ('_new', new_checkpoint),
    ])
def test_remat_of_while_loop(self, remat):
  """Linearizing a rematerialized ``lax.while_loop`` gives the right tangent.

  The loop applies ``sin`` three times (counter runs 0, 1, 2), so the tangent
  at 3.0 must match the gradient of ``sin(sin(sin(x)))``. The linearized jaxpr
  must also contain both ``sin`` and ``cos``.
  """
  def keep_going(state):
    counter, _ = state
    return counter < 3

  def step(state):
    counter, val = state
    return counter + 1, jnp.sin(val)

  def f(x):
    _, result = lax.while_loop(keep_going, step, (0, x))
    return result

  # Numerical check against the explicitly unrolled composition.
  _, f_lin = jax.linearize(remat(f), 3.)
  y_dot = f_lin(1.0)
  expected = jax.grad(lambda x: jnp.sin(jnp.sin(jnp.sin(x))))(3.)
  self.assertArraysAllClose(y_dot, expected, check_dtypes=False)

  # Structural check on the linearized computation.
  jaxpr_text = str(api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.))
  self.assertIn(' sin ', jaxpr_text)
  self.assertIn(' cos ', jaxpr_text)
|
||||
|
||||
def test_remat_of_while_loop_policy(self):
  """A checkpoint policy does not remove ``sin``/``cos`` from the linearized loop.

  Builds a three-iteration ``sin`` loop, wraps it in ``new_checkpoint`` with a
  policy that marks ``cos`` as saveable, and checks that the linearized jaxpr
  still contains both ``sin`` and ``cos``.
  """
  def keep_going(state):
    counter, _ = state
    return counter < 3

  def step(state):
    counter, val = state
    return counter + 1, jnp.sin(val)

  def f(x):
    _, result = lax.while_loop(keep_going, step, (0, x))
    return result

  # even with a policy, we can't save residuals (w/o dynamic shapes)!
  def save_cos(prim, *_, **__):
    return str(prim) == 'cos'

  g = new_checkpoint(f, policy=save_cos)
  jaxpr_text = str(api.make_jaxpr(jax.linearize(g, 4.)[1])(1.))
  self.assertIn(' sin ', jaxpr_text)
  self.assertIn(' cos ', jaxpr_text)
|
||||
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -105,6 +105,15 @@ SCAN_IMPLS_WITH_FOR = [
|
||||
]
|
||||
|
||||
|
||||
def while_loop_new_checkpoint(cond_fun, body_fun, init_val):
  """Run ``lax.while_loop`` wrapped in ``new_checkpoint``.

  Takes the same arguments as ``lax.while_loop``: the loop, with ``cond_fun``
  and ``body_fun`` bound, is wrapped in ``new_checkpoint`` and then applied
  to ``init_val``.
  """
  loop = partial(lax.while_loop, cond_fun, body_fun)
  checkpointed_loop = new_checkpoint(loop)
  return checkpointed_loop(init_val)
|
||||
|
||||
# (implementation, name) pairs: the plain while_loop and its
# new_checkpoint-wrapped variant.
WHILE_LOOP_IMPLS = [
    (lax.while_loop, 'while_loop'),
    (while_loop_new_checkpoint, 'new_checkpoint'),
]
|
||||
|
||||
|
||||
def while_loop_reference(cond, body, carry):
|
||||
while cond(carry):
|
||||
carry = body(carry)
|
||||
@ -2007,13 +2016,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(loop, (x,), order=2, modes=["fwd"])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_jit_loop={}_jit_body={}_jit_cond={}".format(
|
||||
jit_loop, jit_body, jit_cond),
|
||||
"jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond}
|
||||
{"testcase_name": "_jit_loop={}_jit_body={}_jit_cond={}_impl={}".format(
|
||||
jit_loop, jit_body, jit_cond, while_name),
|
||||
"jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond,
|
||||
"while_loop": while_impl}
|
||||
for jit_loop in [False, True]
|
||||
for jit_body in [False, True]
|
||||
for jit_cond in [False, True])
|
||||
def testWhileLinearize(self, jit_loop=True, jit_body=False, jit_cond=True):
|
||||
for jit_cond in [False, True]
|
||||
for while_impl, while_name in WHILE_LOOP_IMPLS)
|
||||
def testWhileLinearize(self, while_loop, jit_loop=True, jit_body=False,
|
||||
jit_cond=True):
|
||||
cond = lambda x: x[0, 2] <= 8
|
||||
body = lambda x: x * x
|
||||
|
||||
@ -2022,7 +2034,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
if jit_body:
|
||||
body = jax.jit(body)
|
||||
|
||||
loop = partial(lax.while_loop, cond, body)
|
||||
loop = partial(while_loop, cond, body)
|
||||
if jit_loop:
|
||||
loop = jax.jit(loop)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user