Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge pull request #10576 from mattjj:new-remat-landing
PiperOrigin-RevId: 455831386
commit b50d77cc98

@@ -388,6 +388,22 @@ def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
batching.axis_primitive_batchers[remat_p] = remat_vmap

# TODO(mattjj,sharadmv): test this more
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
  new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
  new_params = dict(eqn.params, jaxpr=new_jaxpr)
  if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
    return used_inputs, None
  else:
    new_eqn = pe.new_jaxpr_eqn(
        [v for v, used in zip(eqn.invars, used_inputs) if used],
        [v for v, used in zip(eqn.outvars, used_outputs) if used],
        eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
    return used_inputs, new_eqn
pe.dce_rules[remat_p] = remat_dce


def checkpoint_name(x, name):
  return name_p.bind(x, name=name)
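A minimal usage sketch (not part of the commit) of the checkpoint/remat primitive these rules extend; the import alias follows the test files later in this diff:

# Hedged sketch: exercising the new checkpoint (remat) primitive registered above.
# The import alias mirrors the test files in this diff.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint as new_checkpoint

@new_checkpoint
def f(x):
  # Both sines are recomputed on the backward pass rather than saved.
  return jnp.sin(jnp.sin(x))

print(jax.grad(f)(3.0))                  # forward pass plus rematerialized backward pass
print(jax.make_jaxpr(jax.grad(f))(3.0))  # the checkpointed computation shows up in the jaxpr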

@@ -46,6 +46,7 @@ from jax._src.traceback_util import api_boundary
from jax._src.util import (
    cache,
    extend_name_stack,
    partition_list,
    safe_map,
    safe_zip,
    split_list,

@@ -829,6 +830,107 @@ def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
  assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
  return used_inputs, new_eqn

# TODO(mattjj): de-duplicate code with _scan_partial_eval
def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  jaxpr = eqn.params['jaxpr']
  num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
  num_ys = len(jaxpr.out_avals) - num_carry

  # Fixpoint (currently trivial on 'inst_in')
  const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
  for _ in range(1 + len(carry_uk)):
    unks_in = const_uk + carry_uk + xs_uk
    jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr.jaxpr, in_unknowns=unks_in, in_inst=[True] * len(unks_in),
            ensure_out_unknowns=carry_uk + [False] * num_ys,
            ensure_out_inst=True, saveable=saveable)
    carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, jaxpr.consts)
  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)

  # Ensure residuals are all moved to the back.
  # TODO(mattjj): make jaxpr_staged only take instantiated inputs
  res_avals = jaxpr_staged.in_avals[:num_res]
  jaxpr_staged = pe.move_binders_to_back(
      jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))

  # Instantiate all inputs (b/c jaxpr_staged takes all inputs).
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  inst_in = [True] * len(inst_in)

  # As an optimization, hoist loop-invariant residuals out of the loop rather
  # than using extensive outputs for them. See _scan_partial_eval for comments.
  num_const_known = len(const_uk) - sum(const_uk)
  num_carry_known = len(carry_uk) - sum(carry_uk)
  num_xs_known = len(xs_uk) - sum(xs_uk)
  jaxpr_known_hoist, jaxpr_known_loop, loop_dep, _ = \
      pe.partial_eval_jaxpr_nounits(
          jaxpr_known,
          [False] * num_const_known + [True] * (num_carry_known + num_xs_known),
          [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res)
  # jaxpr_known_hoist produces intensive residuals followed by the constants for
  # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts.
  _, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res])
  jaxpr_staged = pe.move_binders_to_front(
      jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
  num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
  del loop_dep, num_carry_known, num_xs_known

  # Create residual variables.
  intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
  ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
               for a in ext_avals_mapped]
  newvar = core.gensym()
  intensive_res = _map(newvar, intensive_avals)
  extensive_res = _map(newvar, ext_avals)

  # Create known eqn, which is a call_p combining evaluation of
  # jaxpr_known_hoist and a scan of jaxpr_known_loop.
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  linear_known = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
  params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
                      num_consts=len(const_uk)-sum(const_uk),
                      num_carry=len(carry_uk)-sum(carry_uk),
                      linear=tuple(linear_known))

  @lu.wrap_init
  def known(*ins_known):
    consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known])
    out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist)
    intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
    out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
    return [*intensive_res, *out_loop]
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      known, [v.aval for v in ins_known])
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*intensive_res, *out_binders_known, *extensive_res],
      core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
      eqn.source_info)

  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
                   [False] * len(extensive_res))
  params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
                       num_consts=len(intensive_res) + eqn.params['num_consts'],
                       linear=tuple(linear_staged))
  eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res],
                                out_binders_staged, eqn.primitive,
                                params_staged, jaxpr_staged.effects,
                                eqn.source_info)

  new_vars = [*new_inst, *intensive_res, *extensive_res]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars

def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry,
                    jaxpr, linear, unroll):
  avals = [x.aval for x in in_atoms]

@@ -899,8 +1001,7 @@ mlir.register_lowering(scan_p,
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
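For orientation, a hedged sketch of what this partial-eval rule enables, mirroring test_remat_of_scan_policy later in this diff: the new checkpoint wraps a lax.scan with a save policy, and the linearized jaxpr shows which primitives still get recomputed.

# Hedged sketch: checkpoint-of-scan with a rematerialization policy, following
# the tests added later in this diff.
import jax
from jax import lax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies

to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
f = new_checkpoint(lambda x: lax.scan(to_scan, x, None, length=3),
                   policy=checkpoint_policies.nothing_saveable)

jaxpr = jax.make_jaxpr(jax.linearize(f, 4.)[1])(1.)  # tangent computation after partial eval
print(str(jaxpr).count(' sin '), str(jaxpr).count(' cos '))  # what still gets recomputed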

@@ -589,7 +589,11 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):


def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  if isinstance(call_jaxpr, core.ClosedJaxpr):
    call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    consts = ()
  all_args, in_tree_def = tree_flatten((consts, args, ct))
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)

@@ -997,7 +997,6 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
  return f_lowered



def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, tokens_in, *args):
  if isinstance(call_jaxpr, core.Jaxpr):

@@ -1041,6 +1040,9 @@ register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
                  partial(_named_call_lowering, name="core_closed_call"))

register_lowering(core.closed_call_p,
                  partial(_named_call_lowering, name="core_closed_call"))


def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
  """Returns an IR constant shaped full of `value` shaped like `aval`."""

@@ -915,7 +915,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
  assert ([v.aval.strip_weak_type() for v in jaxpr_known.outvars] ==
          [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns)
           if not uk] + [a.strip_weak_type() for a in res_avals])
  # check jaxpr_unknown has input type corresponding to unknown inputs plus res
  # check jaxpr_unknown has input type corresponding to res plus unknown inputs
  assert ([v.aval for v in jaxpr_unknown.invars] ==
          res_avals + [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk])
  # check jaxpr_unknown has output type corresponding to unknown outputs

@@ -1092,6 +1092,7 @@ def _partial_eval_jaxpr_custom_cached(

  known_eqns, staged_eqns = [], []
  map(write, in_unknowns, in_inst, jaxpr.invars)
  map(partial(write, False, True), jaxpr.constvars)
  for eqn in jaxpr.eqns:
    unks_in, inst_in = unzip2(map(read, eqn.invars))
    rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)

@@ -1277,17 +1278,20 @@ dce_rules: Dict[Primitive, DCERule] = {}


def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
                        ) -> Tuple[List[bool], JaxprEqn]:
                        ) -> Tuple[List[bool], Optional[JaxprEqn]]:
  new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
  new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
  update_params = call_param_updaters.get(eqn.primitive)
  if update_params:
    new_params = update_params(new_params, used_inputs, 0)
  new_eqn = new_jaxpr_eqn(
      [v for v, used in zip(eqn.invars, used_inputs) if used],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
  return used_inputs, new_eqn
  if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
    return used_inputs, None
  else:
    new_eqn = new_jaxpr_eqn(
        [v for v, used in zip(eqn.invars, used_inputs) if used],
        [v for v, used in zip(eqn.outvars, used_outputs) if used],
        eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
    return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule
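A hedged sketch of calling the DCE entry point these rules plug into; pe here is jax.interpreters.partial_eval as imported by the test files below, and dce_jaxpr is an internal API whose signature is inferred from its uses in this diff:

# Hedged sketch: dropping an unused output with pe.dce_jaxpr, the helper the
# DCE rules above call into (internal API; signature inferred from this diff).
import jax
import jax.numpy as jnp
from jax.interpreters import partial_eval as pe

def f(x):
  return jnp.sin(x), jnp.cos(x)

closed = jax.make_jaxpr(f)(1.0)
new_jaxpr, used_inputs = pe.dce_jaxpr(closed.jaxpr, [True, False])  # keep only sin
print(new_jaxpr)     # the cos equation should be gone
print(used_inputs)   # the single input is still used -> [True]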

@@ -3895,7 +3895,13 @@ class RematTest(jtu.JaxTestCase):
    expected = api.grad(api.grad(f))(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_scan(self):
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', api.remat),
          ('_new', new_checkpoint),
      ])
  def test_remat_scan(self, remat):
    to_scan = lambda c, x: (jnp.sin(c), None)

    def f_noremat(x):

@@ -3903,7 +3909,7 @@ class RematTest(jtu.JaxTestCase):
      return y

    def f_yesremat(x):
      y, _ = lax.scan(api.remat(to_scan), x, np.arange(3.))
      y, _ = lax.scan(remat(to_scan), x, np.arange(3.))
      return y

    ans = f_yesremat(4.)

@@ -3973,7 +3979,13 @@ class RematTest(jtu.JaxTestCase):
    self.assertAllClose(f1(x), f2(x), check_dtypes=False)
    self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)

  def test_remat_symbolic_zeros(self):
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', api.remat),
          ('_new', new_checkpoint),
      ])
  def test_remat_symbolic_zeros(self, remat):
    # code from https://github.com/google/jax/issues/1907

    key = jax.random.PRNGKey(0)

@@ -3994,7 +4006,7 @@ class RematTest(jtu.JaxTestCase):
        F = apply_fn(R)
        return shift(R, 0.001 * F), jnp.array([0.])

      move = api.remat(move)
      move = remat(move)
      R, temp = lax.scan(move, Rinit, jnp.arange(2))
      return R[0, 0]

@@ -4020,10 +4032,16 @@ class RematTest(jtu.JaxTestCase):

    self.assertAllClose(f(3), 6, check_dtypes=False)

  def test_remat_nontrivial_env(self):
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', api.remat),
          ('_new', new_checkpoint),
      ])
  def test_remat_nontrivial_env(self, remat):
    # simplified from https://github.com/google/jax/issues/2030

    @api.remat
    @remat
    def foo(state, dt=0.5, c=1):
      u, u_t = state
      u_tt = c**2 * u

@@ -4081,14 +4099,20 @@ class RematTest(jtu.JaxTestCase):
    f = remat(f)
    api.grad(f)(w, x)  # doesn't crash

  def test_remat_scan2(self):
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', api.remat),
          ('_new', new_checkpoint),
      ])
  def test_remat_scan2(self, remat):
    # https://github.com/google/jax/issues/1963

    def scan_bug(x0):
      f = lambda x, _: (x + 1, None)
      def scanned_f(x, _):
        return lax.scan(f, x, xs=None, length=1)[0], None
      x, _ = jax.remat(scanned_f)(x0, None)
      x, _ = remat(scanned_f)(x0, None)
      return x

    jax.grad(scan_bug)(1.0)  # doesn't crash

@@ -4219,8 +4243,14 @@ class RematTest(jtu.JaxTestCase):
    self.assertTrue('while' in text or 'conditional' in text
                    or 'opt-barrier' in text)

  def test_no_cse_widget_with_prevent_cse_false(self):
    @partial(api.remat, prevent_cse=False)
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', api.remat),
          ('_new', new_checkpoint),
      ])
  def test_no_cse_widget_with_prevent_cse_false(self, remat):
    @partial(remat, prevent_cse=False)
    def g(x):
      return lax.sin(lax.sin(x)), 3.

@@ -4635,15 +4665,202 @@ class RematTest(jtu.JaxTestCase):
    f_vjp(1.)[0].block_until_ready()
    self.assertEqual(count[0], 1)  # fwd execute_trivial, backward_pass on bwd

  def test_remat_of_scan(self):
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', api.remat),
          ('_new', new_checkpoint),
      ])
  def test_remat_of_scan(self, remat):
    to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
    f = lambda x: lax.scan(to_scan, x, None, length=3)
    jtu.check_grads(jax.remat(f), (3.,), order=2, modes=['rev'])
    jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])

    jaxpr = api.make_jaxpr(api.linearize(jax.remat(f), 4.)[1])(1.)
    jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
    self.assertIn(' sin ', str(jaxpr))
    self.assertIn(' cos ', str(jaxpr))

  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', api.remat),
          ('_new', new_checkpoint),
      ])
  def test_const_in_jvp(self, remat):
    @api.custom_jvp
    def f(x):
      return x * np.arange(3.)
    @f.defjvp
    def f_jvp(primals, tangents):
      (x,), (xdot,) = primals, tangents
      return f(x), xdot * np.arange(3.)

    @remat
    def g(x):
      def body(c, _):
        return f(c), None
      y, _ = jax.lax.scan(body, x, None, length=1)
      return y.sum()

    jax.grad(g)(jnp.arange(3.))  # doesn't crash

  def test_remat_checkpoint_dots_outside_scan(self):
    # see also above test test_remat_checkpoint_dots_inside_scan
    x = jnp.ones((5,))

    @partial(new_checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
    def f(W):
      def f(x):
        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
        return x

      def body(x, _): return f(x), None
      return lax.scan(body, x, None, length=2)[0]

    _, f_vjp = api.vjp(f, jnp.ones((5, 5)))
    jaxpr = f_vjp.args[0].func.args[1]
    jaxpr_text = str(jaxpr)

    self.assertEqual(jaxpr_text.count(' sin '), 3)
    self.assertEqual(jaxpr_text.count(' cos '), 3)
    # Six calls to dot_general in the backward pass because we save the primal
    # matmuls and only compute the backward pass ones (two for each primal one).
    self.assertEqual(jaxpr_text.count(' dot_'), 6)

    jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
                    modes=['fwd', 'rev'])

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_scan_policy(self):
    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
    to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
    f = new_checkpoint(lambda x: lax.scan(to_scan, x, None, length=3),
                       policy=save_cos)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])

    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 0)

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_scan_funky_custom_jvp(self):
    def scan_apply(f, x):
      y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1)
      return y

    @api.custom_jvp
    def sin(x):
      return jnp.sin(x)
    def sin_jvp(primals, tangents):
      x, = primals
      xdot, = tangents
      y, c = jax.jit(lambda: (jnp.sin(x), jnp.cos(x)))()
      ydot = c * xdot
      return y, ydot
    sin.defjvp(sin_jvp)

    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
    f = new_checkpoint(partial(scan_apply, sin), policy=save_cos)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 0)

    save_sin = lambda prim, *_, **__: str(prim) == 'sin'
    f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 1)

    f = new_checkpoint(partial(scan_apply, sin),
                       policy=jax.checkpoint_policies.everything_saveable)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 0)

    f = new_checkpoint(partial(scan_apply, sin),
                       policy=jax.checkpoint_policies.nothing_saveable)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 1)  # +1 b/c dce fixed point
    self.assertEqual(jaxpr_text.count(' cos '), 1)

    f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
                       policy=jax.checkpoint_policies.nothing_saveable)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 2)  # +1 b/c dce fixed point
    self.assertEqual(jaxpr_text.count(' cos '), 2)

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_scan_funky_custom_jvp2(self):
    # Like the above test but instead of using jit inside custom_jvp, use scan.

    def scan_apply(f, x):
      y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1)
      return y

    @api.custom_jvp
    def sin(x):
      return jnp.sin(x)
    def sin_jvp(primals, tangents):
      x, = primals
      xdot, = tangents
      y, c = scan_apply(lambda xs: (jnp.sin(xs[0]), jnp.cos(xs[1])), (x, x))
      ydot = c * xdot
      return y, ydot
    sin.defjvp(sin_jvp)

    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
    f = new_checkpoint(partial(scan_apply, sin), policy=save_cos)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 1)  # +1 b/c dce fixed point
    self.assertEqual(jaxpr_text.count(' cos '), 0)

    save_sin = lambda prim, *_, **__: str(prim) == 'sin'
    f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 1)

    f = new_checkpoint(partial(scan_apply, sin),
                       policy=jax.checkpoint_policies.everything_saveable)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 0)

    f = new_checkpoint(partial(scan_apply, sin),
                       policy=jax.checkpoint_policies.nothing_saveable)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 1)  # +1 b/c dce fixed point
    self.assertEqual(jaxpr_text.count(' cos '), 1)

    f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
                       policy=jax.checkpoint_policies.nothing_saveable)
    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
    jaxpr_text = str(jaxpr)
    self.assertEqual(jaxpr_text.count(' sin '), 2)  # +1 b/c dce fixed point
    self.assertEqual(jaxpr_text.count(' cos '), 2)


class JaxprTest(jtu.JaxTestCase):

@@ -36,6 +36,7 @@ from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
from jax.interpreters import partial_eval as pe
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
import jax.numpy as jnp  # scan tests use numpy
import jax.scipy as jsp
from jax._src.lax.control_flow import for_loop

@@ -59,6 +60,17 @@ def cond_via_switch(pred, true_fun, false_fun, op, *args):
  return lax.switch(index, [false_fun, true_fun], op)


# We wanted to try all scan tests with the scan partial evaluation rule that
# happens under ad_checkpoint.checkpoint, so we make a scan wrapper which
# wraps an ad_checkpoint.checkpoint around the computation.
def scan_with_new_checkpoint(f, *args, **kwargs):
  return new_checkpoint(partial(lax.scan, f, **kwargs),
                        policy=checkpoint_policies.nothing_saveable)(*args)
def scan_with_new_checkpoint2(f, *args, **kwargs):
  return new_checkpoint(partial(lax.scan, f, **kwargs),
                        policy=checkpoint_policies.everything_saveable)(*args)
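A small usage sketch (not part of the diff) of how these wrappers behave; the scan body here is illustrative.

# Hedged sketch: the wrappers above are drop-in replacements for lax.scan; the
# checkpoint policy only changes which residuals differentiation may save.
import jax
import jax.numpy as jnp

cumsum_body = lambda c, x: (c + x, c)   # illustrative scan body
carry1, ys1 = scan_with_new_checkpoint(cumsum_body, 0., jnp.arange(4.))
carry2, ys2 = scan_with_new_checkpoint2(cumsum_body, 0., jnp.arange(4.))
# Same values either way; gradients also agree.
g = jax.grad(lambda c: scan_with_new_checkpoint(cumsum_body, c, jnp.arange(4.))[0])(0.)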

COND_IMPLS = [
    (lax.cond, 'cond'),
    (cond_via_switch, 'switch'),

@@ -68,6 +80,8 @@ COND_IMPLS = [
SCAN_IMPLS = [
    (lax.scan, 'unroll1'),
    (partial(lax.scan, unroll=2), 'unroll2'),
    (scan_with_new_checkpoint , 'new_checkpoint'),
    (scan_with_new_checkpoint2, 'new_checkpoint2'),
]

@@ -1534,10 +1548,14 @@ class LaxControlFlowTest(jtu.JaxTestCase):
    as_ = rng.randn(5, 3)
    c = rng.randn(4)

    if scan is scan_with_new_checkpoint2:
      rtol = {np.float64: 1e-12, np.float32: 1e-4}
    else:
      rtol = {np.float64: 1e-14, np.float32: 1e-4}

    ans = jax.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
    expected = jax.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
    self.assertAllClose(ans, expected, check_dtypes=False,
                        rtol={np.float64: 1e-14, np.float32: 1e-4})
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol)

  @parameterized.named_parameters(
      {"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(

@@ -1569,11 +1587,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):

    ans = jax.grad(lambda c, as_: list( scan(f, c, as_))[0].sum())(c, as_)
    expected = jax.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
    self.assertAllClose(ans, expected, check_dtypes=False,
                        rtol={np.float32: 2e-5, np.float64: 1e-13})
    if scan is scan_with_new_checkpoint:
      rtol = {np.float32: 5e-5, np.float64: 1e-13}
    else:
      rtol = {np.float32: 2e-5, np.float64: 1e-13}
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol)

    rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
    jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
                    atol=1e-3, rtol=5e-3)
                    atol=1e-3, rtol=rtol)

  @jtu.skip_on_devices("tpu")  # TPU lacks precision for this test.
  @jtu.skip_on_flag("jax_skip_slow_tests", True)

@@ -1642,7 +1664,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
        batched_inputs, batched_targets)))
    self.assertAllClose(losses, expected, check_dtypes=False, rtol=1e-2)

  def testIssue711(self):
  @parameterized.named_parameters(
      {"testcase_name": "_impl={}".format(scan_name), "scan": scan_impl}
      for scan_impl, scan_name in SCAN_IMPLS)
  def testIssue711(self, scan):
    # Tests reverse-mode differentiation through a scan for which the scanned
    # function also involves reverse-mode differentiation.
    # See https://github.com/google/jax/issues/711

@@ -1659,7 +1684,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
        return new_carry, _

      x0 = jnp.array([1., 2., 3.])
      carry_final, _ = lax.scan(apply_carry, (0, x0), jnp.zeros((75, 0)))
      carry_final, _ = scan(apply_carry, (0, x0), jnp.zeros((75, 0)))
      _, x_final = carry_final
      return x_final

@@ -1799,13 +1824,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
    ans = jax.vmap(lambda c, as_: lax.scan(f, c, as_), in_axes)(c, as_)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testScanVmapFixpoint(self):
  @parameterized.named_parameters(
      {"testcase_name": "_impl={}".format(scan_name), "scan": scan_impl}
      for scan_impl, scan_name in SCAN_IMPLS)
  def testScanVmapFixpoint(self, scan):
    def f(carry_init):
      def scan_body(c, x):
        # The carry is a 4-tuple, the last element starts batched,
        # and the carry is shifted left at each iteration.
        return ((c[1], c[2], c[3], 0.), None)
      return lax.scan(scan_body, (0., 1., 2., carry_init), jnp.zeros(2))
      return scan(scan_body, (0., 1., 2., carry_init), jnp.zeros(2))
    carry_init = jnp.array([3., 4., 5.])
    carry_out, _ = jax.vmap(f)(carry_init)
    self.assertAllClose(carry_out[3], jnp.array([0., 0., 0.]), check_dtypes=False)

@@ -2338,21 +2366,33 @@ class LaxControlFlowTest(jtu.JaxTestCase):
    result = lax.while_loop(cond_fun, body_fun, init_weak)
    self.assertArraysEqual(result, jnp.full_like(increment, 2))

  def test_scan_vjp_forwards_extensive_residuals(self):
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', None),
          ('new_remat', new_checkpoint),
      ])
  def test_scan_vjp_forwards_extensive_residuals(self, remat):
    # https://github.com/google/jax/issues/4510
    def cumprod(x):
      s = jnp.ones((2, 32), jnp.float32)
      return lax.scan(lambda s, x: (x*s, s), s, x)
    if remat is not None:
      cumprod = remat(cumprod)

    rng = self.rng()
    x = jnp.asarray(rng.randn(32, 2, 32).astype('float32'))
    _, vjp_fun = jax.vjp(cumprod, x)

    # Need to spelunk into vjp_fun. This is fragile, and if it causes problems
    # just skip this test.
    # just skip this test and make an issue for mattjj.
    *_, ext_res = vjp_fun.args[0].args[0]
    self.assertIs(ext_res, x)

    if remat is not None:
      # TODO(mattjj): make the numpy.ndarray test pass w/ remat
      raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray")

    x = rng.randn(32, 2, 32).astype('float32')  # numpy.ndarray, not DeviceArray
    _, vjp_fun = jax.vjp(cumprod, x)
    *_, ext_res = vjp_fun.args[0].args[0]

@@ -48,6 +48,7 @@ from jax.interpreters import pxla
from jax.interpreters import xla
from jax.experimental import array
from jax.experimental.sharding import PmapSharding
from jax.ad_checkpoint import checkpoint as new_checkpoint

from jax.config import config
config.parse_flags_with_absl()

@@ -1673,7 +1674,13 @@ class PythonPmapTest(jtu.JaxTestCase):
    tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
    self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)

  def testAxisIndexRemat(self):
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', jax.remat),
          ('_new', new_checkpoint),
      ])
  def testAxisIndexRemat(self, remat):
    # https://github.com/google/jax/issues/2716
    n = len(jax.devices())

@@ -1682,7 +1689,7 @@ class PythonPmapTest(jtu.JaxTestCase):
      return random.bernoulli(key, p=0.5)

    keys = random.split(random.PRNGKey(0), n)
    self.pmap(jax.remat(f), axis_name='i')(keys)
    self.pmap(remat(f), axis_name='i')(keys)

  def testPmapMapVmapCombinations(self):
    # https://github.com/google/jax/issues/2822