Merge pull request #10576 from mattjj:new-remat-landing

PiperOrigin-RevId: 455831386
This commit is contained in:
jax authors 2022-06-18 12:33:52 -07:00
commit b50d77cc98
8 changed files with 428 additions and 37 deletions

View File

@ -388,6 +388,22 @@ def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
batching.axis_primitive_batchers[remat_p] = remat_vmap
# TODO(mattjj,sharadmv): test this more
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
  """Dead-code-elimination rule for an equation carrying a `jaxpr` param.

  The equation's inner jaxpr is pruned with `pe.dce_jaxpr` against
  `used_outputs`, and the equation is rebuilt around whatever survives.

  Args:
    used_outputs: flags zipped against `eqn.outvars`; an output binder is kept
      only where its flag is truthy.
    eqn: the equation to prune; its `jaxpr` param is the body handed to
      `pe.dce_jaxpr`.

  Returns:
    A pair `(used_inputs, new_eqn)`. `used_inputs` is the second result of
    `pe.dce_jaxpr`. `new_eqn` is `None` when no input is used, no output is
    used, and the pruned jaxpr has no effects; otherwise it is a fresh
    equation with the same primitive and source info, restricted to the used
    inputs and outputs.
  """
  pruned_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)

  # Keep the equation if anything flows in or out of it, or if running the
  # pruned body would still have effects.
  if any(used_inputs) or any(used_outputs) or pruned_jaxpr.effects:
    kept_invars = [v for v, used in zip(eqn.invars, used_inputs) if used]
    kept_outvars = [v for v, used in zip(eqn.outvars, used_outputs) if used]
    pruned_params = dict(eqn.params, jaxpr=pruned_jaxpr)
    pruned_eqn = pe.new_jaxpr_eqn(
        kept_invars, kept_outvars, eqn.primitive, pruned_params,
        pruned_jaxpr.effects, eqn.source_info)
    return used_inputs, pruned_eqn

  return used_inputs, None
pe.dce_rules[remat_p] = remat_dce
def checkpoint_name(x, name):
  """Bind the `name_p` primitive to `x`, passing `name` as its `name` param."""
  named = name_p.bind(x, name=name)
  return named

View File

@ -46,6 +46,7 @@ from jax._src.traceback_util import api_boundary
from jax._src.util import (
cache,
extend_name_stack,
partition_list,
safe_map,
safe_zip,
split_list,
@ -829,6 +830,107 @@ def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
return used_inputs, new_eqn
# TODO(mattjj): de-duplicate code with _scan_partial_eval
def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Split a scan equation into a known equation and a staged equation.

  This is the rule stored in `pe.partial_eval_jaxpr_custom_rules[scan_p]`.

  Args:
    saveable: forwarded unchanged to `pe.partial_eval_jaxpr_custom`.
    unks_in: one flag per scan operand (presumably True = unknown -- confirm
      against `pe.partial_eval_jaxpr_custom`); split into const/carry/xs
      groups using the equation's `num_consts` and `num_carry`.
    inst_in: one flag per scan operand; operands that are `core.Var` and whose
      flag is falsy are reported back in `new_vars`.
    eqn: the scan equation. Its `jaxpr`, `num_consts`, `num_carry`, `length`
      and `linear` params are read here.

  Returns:
    `(eqn_known, eqn_staged, unks_out, inst_out, new_vars)`:
    `eqn_known` is a `core.closed_call_p` equation over the operands whose
    `unks_in` flag is falsy; `eqn_staged` is an equation with `eqn.primitive`
    over the intensive residuals, all original operands and the extensive
    residuals; `unks_out` / `inst_out` are the output flags produced by
    `pe.partial_eval_jaxpr_custom`; `new_vars` lists the newly instantiated
    input variables followed by the fresh residual variables.
  """
  jaxpr = eqn.params['jaxpr']
  num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
  # Outputs of the body beyond the carry are the per-iteration `ys`.
  num_ys = len(jaxpr.out_avals) - num_carry
  # Fixpoint (currently trivial on 'inst_in')
  const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
  # Iterate until the carry flags coming out of partial evaluation equal the
  # carry flags going in. Each non-converged pass ORs the output flags into
  # `carry_uk`; the bound of 1 + len(carry_uk) passes presumably relies on
  # flags only ever flipping from False to True.
  for _ in range(1 + len(carry_uk)):
    unks_in = const_uk + carry_uk + xs_uk
    jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr.jaxpr, in_unknowns=unks_in, in_inst=[True] * len(unks_in),
            ensure_out_unknowns=carry_uk + [False] * num_ys,
            ensure_out_inst=True, saveable=saveable)
    carry_uk_out , ys_uk = split_list(unks_out, [num_carry])
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk , carry_uk_out )
  else:
    # Reached only when the loop ran out of iterations without `break`.
    assert False, "Fixpoint not reached"
  # Re-close both halves over the original body's consts.
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)
  # Ensure residuals are all moved to the back.
  # TODO(mattjj): make jaxpr_staged only take instantiated inputs
  # The residual avals are read off the front of jaxpr_staged's inputs before
  # those binders are moved.
  res_avals = jaxpr_staged.in_avals[:num_res]
  jaxpr_staged = pe.move_binders_to_back(
      jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))
  # Instantiate all inputs (b/c jaxpr_staged takes all inputs).
  # `new_inst` records the variable operands that were not instantiated on
  # entry; they are returned to the caller as part of `new_vars`.
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  inst_in = [True] * len(inst_in)
  # As an optimization, hoist loop-invariant residuals out of the loop rather
  # than using extensive outputs for them. See _scan_partial_eval for comments.
  # Counts of flags that are False in each operand group.
  num_const_known = len(const_uk) - sum(const_uk)
  num_carry_known = len(carry_uk) - sum(carry_uk)
  num_xs_known = len( xs_uk) - sum( xs_uk)
  # Second partial evaluation, this time of jaxpr_known: known consts are
  # marked False and known carry/xs are marked True in the input mask, so the
  # "hoist" half depends on the consts only.
  jaxpr_known_hoist, jaxpr_known_loop, loop_dep, _ = \
      pe.partial_eval_jaxpr_nounits(
          jaxpr_known,
          [False] * num_const_known + [True] * (num_carry_known + num_xs_known),
          [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res)
  # jaxpr_known_hoist produces intensive residuals followed by the constants for
  # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts.
  # The trailing `num_res` entries of `loop_dep` are the flags for the
  # residual outputs.
  _, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res])
  jaxpr_staged = pe.move_binders_to_front(
      jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
  num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
  del loop_dep, num_carry_known, num_xs_known
  # Create residual variables.
  intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
  # Extensive residual avals are the per-iteration avals passed through
  # core.unmapped_aval with the scan length at axis 0.
  ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
               for a in ext_avals_mapped]
  newvar = core.gensym()
  intensive_res = _map(newvar, intensive_avals)
  extensive_res = _map(newvar, ext_avals)
  # Create known eqn, which is a call_p combining evaluation of
  # jaxpr_known_hoist and a scan of jaxpr_known_loop.
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  linear_known = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
  params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
                      num_consts=len(const_uk)-sum(const_uk),
                      num_carry=len(carry_uk)-sum(carry_uk),
                      linear=tuple(linear_known))
  @lu.wrap_init
  def known(*ins_known):
    # Evaluate the hoisted part on the known consts, then run the known scan
    # on its remaining outputs plus the other known operands.
    consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known])
    out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist)
    intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
    out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
    return [*intensive_res, *out_loop]
  # Trace `known` to a jaxpr so it can be wrapped in a closed_call_p equation.
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      known, [v.aval for v in ins_known])
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  # Output binders are ordered to match what `known` returns: intensive
  # residuals first, then the scan outputs (known outputs, then extensive
  # residuals).
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*intensive_res, *out_binders_known, *extensive_res],
      core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
      eqn.source_info)
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  # Residual operands of the staged scan are all marked non-linear.
  linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
                   [False] * len(extensive_res))
  # Intensive residuals are passed to the staged scan as additional consts.
  params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
                       num_consts=len(intensive_res) + eqn.params['num_consts'],
                       linear=tuple(linear_staged))
  eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res],
                                out_binders_staged, eqn.primitive,
                                params_staged, jaxpr_staged.effects,
                                eqn.source_info)
  new_vars = [*new_inst, *intensive_res, *extensive_res]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
avals = [x.aval for x in in_atoms]
@ -899,8 +1001,7 @@ mlir.register_lowering(scan_p,
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule

View File

@ -589,7 +589,11 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
if isinstance(call_jaxpr, core.ClosedJaxpr):
call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
else:
consts = ()
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)

View File

@ -997,7 +997,6 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
return f_lowered
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
avals_out, tokens_in, *args):
if isinstance(call_jaxpr, core.Jaxpr):
@ -1041,6 +1040,9 @@ register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
# Lowering rule for core.closed_call_p (registered a single time).
register_lowering(core.closed_call_p,
                  partial(_named_call_lowering, name="core_closed_call"))
def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""

View File

@ -915,7 +915,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
assert ([v.aval.strip_weak_type() for v in jaxpr_known.outvars] ==
[a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns)
if not uk] + [a.strip_weak_type() for a in res_avals])
# check jaxpr_unknown has input type corresponding to unknown inputs plus res
# check jaxpr_unknown has input type corresponding to res plus unknown inputs
assert ([v.aval for v in jaxpr_unknown.invars] ==
res_avals + [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk])
# check jaxpr_unknown has output type corresponding to unknown outputs
@ -1092,6 +1092,7 @@ def _partial_eval_jaxpr_custom_cached(
known_eqns, staged_eqns = [], []
map(write, in_unknowns, in_inst, jaxpr.invars)
map(partial(write, False, True), jaxpr.constvars)
for eqn in jaxpr.eqns:
unks_in, inst_in = unzip2(map(read, eqn.invars))
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
@ -1277,17 +1278,20 @@ dce_rules: Dict[Primitive, DCERule] = {}
def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
                        ) -> Tuple[List[bool], Optional[JaxprEqn]]:
  """Dead-code-elimination rule for call-like equations with a `call_jaxpr`.

  The equation's `call_jaxpr` param is pruned with `dce_jaxpr` against
  `used_outputs`. If a parameter updater is registered for the primitive in
  `call_param_updaters`, it is applied to the new params.

  Args:
    used_outputs: flags zipped against `eqn.outvars`; an output binder is kept
      only where its flag is truthy.
    eqn: the call equation to prune.

  Returns:
    A pair `(used_inputs, new_eqn)`. `used_inputs` is the second result of
    `dce_jaxpr`. `new_eqn` is `None` when no input is used, no output is used,
    and the pruned jaxpr has no effects; otherwise it is a new equation with
    the same primitive and source info, restricted to the used inputs and
    outputs.
  """
  new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
  new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
  update_params = call_param_updaters.get(eqn.primitive)
  if update_params:
    new_params = update_params(new_params, used_inputs, 0)
  # Drop the equation entirely only when nothing flows through it and the
  # pruned body has no effects; this check must precede any return of new_eqn.
  if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
    return used_inputs, None
  new_eqn = new_jaxpr_eqn(
      [v for v, used in zip(eqn.invars, used_inputs) if used],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
  return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule

View File

@ -3895,7 +3895,13 @@ class RematTest(jtu.JaxTestCase):
expected = api.grad(api.grad(f))(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_scan(self):
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_remat_scan(self, remat):
to_scan = lambda c, x: (jnp.sin(c), None)
def f_noremat(x):
@ -3903,7 +3909,7 @@ class RematTest(jtu.JaxTestCase):
return y
def f_yesremat(x):
y, _ = lax.scan(api.remat(to_scan), x, np.arange(3.))
y, _ = lax.scan(remat(to_scan), x, np.arange(3.))
return y
ans = f_yesremat(4.)
@ -3973,7 +3979,13 @@ class RematTest(jtu.JaxTestCase):
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
def test_remat_symbolic_zeros(self):
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_remat_symbolic_zeros(self, remat):
# code from https://github.com/google/jax/issues/1907
key = jax.random.PRNGKey(0)
@ -3994,7 +4006,7 @@ class RematTest(jtu.JaxTestCase):
F = apply_fn(R)
return shift(R, 0.001 * F), jnp.array([0.])
move = api.remat(move)
move = remat(move)
R, temp = lax.scan(move, Rinit, jnp.arange(2))
return R[0, 0]
@ -4020,10 +4032,16 @@ class RematTest(jtu.JaxTestCase):
self.assertAllClose(f(3), 6, check_dtypes=False)
def test_remat_nontrivial_env(self):
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_remat_nontrivial_env(self, remat):
# simplified from https://github.com/google/jax/issues/2030
@api.remat
@remat
def foo(state, dt=0.5, c=1):
u, u_t = state
u_tt = c**2 * u
@ -4081,14 +4099,20 @@ class RematTest(jtu.JaxTestCase):
f = remat(f)
api.grad(f)(w, x) # doesn't crash
def test_remat_scan2(self):
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_remat_scan2(self, remat):
# https://github.com/google/jax/issues/1963
def scan_bug(x0):
f = lambda x, _: (x + 1, None)
def scanned_f(x, _):
return lax.scan(f, x, xs=None, length=1)[0], None
x, _ = jax.remat(scanned_f)(x0, None)
x, _ = remat(scanned_f)(x0, None)
return x
jax.grad(scan_bug)(1.0) # doesn't crash
@ -4219,8 +4243,14 @@ class RematTest(jtu.JaxTestCase):
self.assertTrue('while' in text or 'conditional' in text
or 'opt-barrier' in text)
def test_no_cse_widget_with_prevent_cse_false(self):
@partial(api.remat, prevent_cse=False)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_no_cse_widget_with_prevent_cse_false(self, remat):
@partial(remat, prevent_cse=False)
def g(x):
return lax.sin(lax.sin(x)), 3.
@ -4635,15 +4665,202 @@ class RematTest(jtu.JaxTestCase):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
def test_remat_of_scan(self):
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_remat_of_scan(self, remat):
to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
f = lambda x: lax.scan(to_scan, x, None, length=3)
jtu.check_grads(jax.remat(f), (3.,), order=2, modes=['rev'])
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(jax.remat(f), 4.)[1])(1.)
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
@parameterized.named_parameters(
    {"testcase_name": f"{suffix}", "remat": remat}
    for suffix, remat in [
        ('', api.remat),
        ('_new', new_checkpoint),
    ])
def test_const_in_jvp(self, remat):
  # Differentiates a remat-wrapped function whose scan body calls a
  # custom_jvp function that multiplies by a NumPy constant in both its
  # primal and its JVP rule. The test only checks that this runs.
  @api.custom_jvp
  def scale(x):
    return x * np.arange(3.)

  def scale_jvp(primals, tangents):
    x, = primals
    xdot, = tangents
    return scale(x), xdot * np.arange(3.)
  scale.defjvp(scale_jvp)

  def step(carry, _):
    return scale(carry), None

  def total(x):
    final, _ = jax.lax.scan(step, x, None, length=1)
    return final.sum()

  jax.grad(remat(total))(jnp.arange(3.))  # doesn't crash
def test_remat_checkpoint_dots_outside_scan(self):
  # see also above test test_remat_checkpoint_dots_inside_scan
  # Here the checkpoint (with the checkpoint_dots policy) wraps the whole
  # scan rather than the scan body.
  x = jnp.ones((5,))
  @partial(new_checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
  def f(W):
    # NOTE: this inner `f` shadows the outer one inside this scope only; it
    # applies three sin(dot(x, W)) layers.
    def f(x):
      x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
      x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
      x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
      return x
    def body(x, _): return f(x), None
    return lax.scan(body, x, None, length=2)[0]
  _, f_vjp = api.vjp(f, jnp.ones((5, 5)))
  # Digs the backward-pass jaxpr out of the vjp closure; this depends on the
  # internal structure of the object returned by api.vjp.
  jaxpr = f_vjp.args[0].func.args[1]
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 3)
  self.assertEqual(jaxpr_text.count(' cos '), 3)
  # Six calls to dot_general in the backward pass because we save the primal
  # matmuls and only compute the backward pass ones (two for each primal one).
  self.assertEqual(jaxpr_text.count(' dot_'), 6)
  jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
                  modes=['fwd', 'rev'])
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_scan_policy(self):
  # Checkpoints a 3-step scan of sin under a policy that returns True only
  # for the primitive whose string form is 'cos', then checks gradients and
  # the text of the linearized jaxpr.
  save_cos = lambda prim, *_, **__: str(prim) == 'cos'
  to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
  f = new_checkpoint(lambda x: lax.scan(to_scan, x, None, length=3),
                     policy=save_cos)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  # Neither sin nor cos is expected to appear in the linearized function.
  self.assertEqual(jaxpr_text.count(' sin '), 0)
  self.assertEqual(jaxpr_text.count(' cos '), 0)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_scan_funky_custom_jvp(self):
  # Exercises checkpoint-of-scan with a custom_jvp `sin` whose JVP rule
  # computes sin and cos inside a jit. Each stanza below picks a policy,
  # checks reverse-mode gradients, and counts ' sin ' / ' cos ' occurrences in
  # the text of the linearized jaxpr.

  # Applies `f` once via a length-1 scan.
  def scan_apply(f, x):
    y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1)
    return y
  @api.custom_jvp
  def sin(x):
    return jnp.sin(x)
  def sin_jvp(primals, tangents):
    x, = primals
    xdot, = tangents
    # Both the primal output and the cos coefficient come out of one jitted
    # nullary function closing over x.
    y, c = jax.jit(lambda: (jnp.sin(x), jnp.cos(x)))()
    ydot = c * xdot
    return y, ydot
  sin.defjvp(sin_jvp)

  # Policy: True only for the primitive whose string form is 'cos'.
  save_cos = lambda prim, *_, **__: str(prim) == 'cos'
  f = new_checkpoint(partial(scan_apply, sin), policy=save_cos)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 0)
  self.assertEqual(jaxpr_text.count(' cos '), 0)

  # Policy: True only for the primitive whose string form is 'sin'.
  save_sin = lambda prim, *_, **__: str(prim) == 'sin'
  f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 0)
  self.assertEqual(jaxpr_text.count(' cos '), 1)

  # Policy: everything_saveable.
  f = new_checkpoint(partial(scan_apply, sin),
                     policy=jax.checkpoint_policies.everything_saveable)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 0)
  self.assertEqual(jaxpr_text.count(' cos '), 0)

  # Policy: nothing_saveable.
  f = new_checkpoint(partial(scan_apply, sin),
                     policy=jax.checkpoint_policies.nothing_saveable)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 1)  # +1 b/c dce fixed point
  self.assertEqual(jaxpr_text.count(' cos '), 1)

  # Policy: nothing_saveable, with two scan_apply(sin, ...) calls composed.
  f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
                     policy=jax.checkpoint_policies.nothing_saveable)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 2)  # +1 b/c dce fixed point
  self.assertEqual(jaxpr_text.count(' cos '), 2)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_scan_funky_custom_jvp2(self):
  # Like the above test but instead of using jit inside custom_jvp, use scan.
  # Each stanza below picks a policy, checks reverse-mode gradients, and
  # counts ' sin ' / ' cos ' occurrences in the text of the linearized jaxpr.

  # Applies `f` once via a length-1 scan.
  def scan_apply(f, x):
    y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1)
    return y
  @api.custom_jvp
  def sin(x):
    return jnp.sin(x)
  def sin_jvp(primals, tangents):
    x, = primals
    xdot, = tangents
    # The carry here is the pair (x, x); the scanned function maps it to
    # (sin of the first entry, cos of the second entry).
    y, c = scan_apply(lambda xs: (jnp.sin(xs[0]), jnp.cos(xs[1])), (x, x))
    ydot = c * xdot
    return y, ydot
  sin.defjvp(sin_jvp)

  # Policy: True only for the primitive whose string form is 'cos'.
  save_cos = lambda prim, *_, **__: str(prim) == 'cos'
  f = new_checkpoint(partial(scan_apply, sin), policy=save_cos)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 1)  # +1 b/c dce fixed point
  self.assertEqual(jaxpr_text.count(' cos '), 0)

  # Policy: True only for the primitive whose string form is 'sin'.
  save_sin = lambda prim, *_, **__: str(prim) == 'sin'
  f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 0)
  self.assertEqual(jaxpr_text.count(' cos '), 1)

  # Policy: everything_saveable.
  f = new_checkpoint(partial(scan_apply, sin),
                     policy=jax.checkpoint_policies.everything_saveable)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 0)
  self.assertEqual(jaxpr_text.count(' cos '), 0)

  # Policy: nothing_saveable.
  f = new_checkpoint(partial(scan_apply, sin),
                     policy=jax.checkpoint_policies.nothing_saveable)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 1)  # +1 b/c dce fixed point
  self.assertEqual(jaxpr_text.count(' cos '), 1)

  # Policy: nothing_saveable, with two scan_apply(sin, ...) calls composed.
  f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
                     policy=jax.checkpoint_policies.nothing_saveable)
  jtu.check_grads(f, (3.,), order=2, modes=['rev'])
  jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
  jaxpr_text = str(jaxpr)
  self.assertEqual(jaxpr_text.count(' sin '), 2)  # +1 b/c dce fixed point
  self.assertEqual(jaxpr_text.count(' cos '), 2)
class JaxprTest(jtu.JaxTestCase):

View File

@ -36,6 +36,7 @@ from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
from jax.interpreters import partial_eval as pe
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
from jax._src.lax.control_flow import for_loop
@ -59,6 +60,17 @@ def cond_via_switch(pred, true_fun, false_fun, op, *args):
return lax.switch(index, [false_fun, true_fun], op)
# We want to run all scan tests with the scan partial evaluation rule that
# is used under ad_checkpoint.checkpoint, so we define scan wrappers which
# wrap an ad_checkpoint.checkpoint around the computation.
def scan_with_new_checkpoint(f, *args, **kwargs):
return new_checkpoint(partial(lax.scan, f, **kwargs),
policy=checkpoint_policies.nothing_saveable)(*args)
def scan_with_new_checkpoint2(f, *args, **kwargs):
return new_checkpoint(partial(lax.scan, f, **kwargs),
policy=checkpoint_policies.everything_saveable)(*args)
COND_IMPLS = [
(lax.cond, 'cond'),
(cond_via_switch, 'switch'),
@ -68,6 +80,8 @@ COND_IMPLS = [
SCAN_IMPLS = [
(lax.scan, 'unroll1'),
(partial(lax.scan, unroll=2), 'unroll2'),
(scan_with_new_checkpoint , 'new_checkpoint'),
(scan_with_new_checkpoint2, 'new_checkpoint2'),
]
@ -1534,10 +1548,14 @@ class LaxControlFlowTest(jtu.JaxTestCase):
as_ = rng.randn(5, 3)
c = rng.randn(4)
if scan is scan_with_new_checkpoint2:
rtol = {np.float64: 1e-12, np.float32: 1e-4}
else:
rtol = {np.float64: 1e-14, np.float32: 1e-4}
ans = jax.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
expected = jax.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
self.assertAllClose(ans, expected, check_dtypes=False,
rtol={np.float64: 1e-14, np.float32: 1e-4})
self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol)
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
@ -1569,11 +1587,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):
ans = jax.grad(lambda c, as_: list( scan(f, c, as_))[0].sum())(c, as_)
expected = jax.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
self.assertAllClose(ans, expected, check_dtypes=False,
rtol={np.float32: 2e-5, np.float64: 1e-13})
if scan is scan_with_new_checkpoint:
rtol = {np.float32: 5e-5, np.float64: 1e-13}
else:
rtol = {np.float32: 2e-5, np.float64: 1e-13}
self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol)
rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
atol=1e-3, rtol=5e-3)
atol=1e-3, rtol=rtol)
@jtu.skip_on_devices("tpu") # TPU lacks precision for this test.
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@ -1642,7 +1664,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
batched_inputs, batched_targets)))
self.assertAllClose(losses, expected, check_dtypes=False, rtol=1e-2)
def testIssue711(self):
@parameterized.named_parameters(
{"testcase_name": "_impl={}".format(scan_name), "scan": scan_impl}
for scan_impl, scan_name in SCAN_IMPLS)
def testIssue711(self, scan):
# Tests reverse-mode differentiation through a scan for which the scanned
# function also involves reverse-mode differentiation.
# See https://github.com/google/jax/issues/711
@ -1659,7 +1684,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return new_carry, _
x0 = jnp.array([1., 2., 3.])
carry_final, _ = lax.scan(apply_carry, (0, x0), jnp.zeros((75, 0)))
carry_final, _ = scan(apply_carry, (0, x0), jnp.zeros((75, 0)))
_, x_final = carry_final
return x_final
@ -1799,13 +1824,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
ans = jax.vmap(lambda c, as_: lax.scan(f, c, as_), in_axes)(c, as_)
self.assertAllClose(ans, expected, check_dtypes=False)
def testScanVmapFixpoint(self):
@parameterized.named_parameters(
{"testcase_name": "_impl={}".format(scan_name), "scan": scan_impl}
for scan_impl, scan_name in SCAN_IMPLS)
def testScanVmapFixpoint(self, scan):
def f(carry_init):
def scan_body(c, x):
# The carry is a 4-tuple, the last element starts batched,
# and the carry is shifted left at each iteration.
return ((c[1], c[2], c[3], 0.), None)
return lax.scan(scan_body, (0., 1., 2., carry_init), jnp.zeros(2))
return scan(scan_body, (0., 1., 2., carry_init), jnp.zeros(2))
carry_init = jnp.array([3., 4., 5.])
carry_out, _ = jax.vmap(f)(carry_init)
self.assertAllClose(carry_out[3], jnp.array([0., 0., 0.]), check_dtypes=False)
@ -2338,21 +2366,33 @@ class LaxControlFlowTest(jtu.JaxTestCase):
result = lax.while_loop(cond_fun, body_fun, init_weak)
self.assertArraysEqual(result, jnp.full_like(increment, 2))
def test_scan_vjp_forwards_extensive_residuals(self):
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', None),
('new_remat', new_checkpoint),
])
def test_scan_vjp_forwards_extensive_residuals(self, remat):
# https://github.com/google/jax/issues/4510
def cumprod(x):
s = jnp.ones((2, 32), jnp.float32)
return lax.scan(lambda s, x: (x*s, s), s, x)
if remat is not None:
cumprod = remat(cumprod)
rng = self.rng()
x = jnp.asarray(rng.randn(32, 2, 32).astype('float32'))
_, vjp_fun = jax.vjp(cumprod, x)
# Need to spelunk into vjp_fun. This is fragile, and if it causes problems
# just skip this test.
# just skip this test and make an issue for mattjj.
*_, ext_res = vjp_fun.args[0].args[0]
self.assertIs(ext_res, x)
if remat is not None:
# TODO(mattjj): make the numpy.ndarray test pass w/ remat
raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray")
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray
_, vjp_fun = jax.vjp(cumprod, x)
*_, ext_res = vjp_fun.args[0].args[0]

View File

@ -48,6 +48,7 @@ from jax.interpreters import pxla
from jax.interpreters import xla
from jax.experimental import array
from jax.experimental.sharding import PmapSharding
from jax.ad_checkpoint import checkpoint as new_checkpoint
from jax.config import config
config.parse_flags_with_absl()
@ -1673,7 +1674,13 @@ class PythonPmapTest(jtu.JaxTestCase):
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)
def testAxisIndexRemat(self):
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', jax.remat),
('_new', new_checkpoint),
])
def testAxisIndexRemat(self, remat):
# https://github.com/google/jax/issues/2716
n = len(jax.devices())
@ -1682,7 +1689,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return random.bernoulli(key, p=0.5)
keys = random.split(random.PRNGKey(0), n)
self.pmap(jax.remat(f), axis_name='i')(keys)
self.pmap(remat(f), axis_name='i')(keys)
def testPmapMapVmapCombinations(self):
# https://github.com/google/jax/issues/2822