mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
improve partial_eval_jaxpr_custom
* add caching via weakref_lru_cache * add inst_in argument (needed for fixpoints for loop primitives, in follow-up PR), update callers not to over-instantiate inputs (previously I had used a convention where call primitives would just stage out eqns with all inputs instantiated, for expedience) * add ensure_out_unknowns and ensure_out_inst arguments, analogues of `instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also needed for fixpoints of loop primitives) * better dce in remat_partial_eval (e.g. prune unused residuals)
This commit is contained in:
parent
705e241409
commit
7e241b682d
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import operator as op
|
||||
from typing import Callable, Optional, List, Tuple
|
||||
import types
|
||||
|
||||
@ -29,7 +30,7 @@ from jax._src import source_info_util
|
||||
from jax._src.api_util import flatten_fun
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
|
||||
safe_zip)
|
||||
safe_zip, merge_lists)
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
||||
@ -302,43 +303,46 @@ ad.primitive_jvps[remat_p] = remat_jvp
|
||||
|
||||
def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
assert not jaxpr.constvars
|
||||
policy = params['policy'] or (lambda *_, **__: False)
|
||||
# unzip into jaxpr_known and jaxpr_unknown
|
||||
policy = params['policy'] or nothing_saveable
|
||||
in_unknowns = [not t.is_known() for t in tracers]
|
||||
# TODO(mattjj): use cached version of pe.partial_eval_jaxpr_custom
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
|
||||
pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
|
||||
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
|
||||
_, used_outs_unknown = partition_list(out_inst, out_unknowns)
|
||||
jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)
|
||||
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
|
||||
pe.partial_eval_jaxpr_custom(
|
||||
jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)
|
||||
|
||||
# DCE jaxpr_staged, keeping only instantiated outputs which are unknown
|
||||
_, out_inst_unknown = partition_list(out_inst, out_unknowns)
|
||||
jaxpr_unknown, in_used_staged = pe.dce_jaxpr(jaxpr_staged, out_inst_unknown)
|
||||
used_res, in_used_staged = split_list(in_used_staged, [num_res])
|
||||
|
||||
# DCE jaxpr_known, keeping all known outputs but discarding dce'd res
|
||||
out_used_known = [True] * (len(out_unknowns) - sum(out_unknowns)) + used_res
|
||||
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
|
||||
num_res = sum(used_res)
|
||||
|
||||
# compute known outputs and residuals (hoisted out of remat primitive)
|
||||
_, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
|
||||
_, in_consts = partition_list(in_used_known, in_consts_)
|
||||
out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
|
||||
out_consts_ = iter(out_consts)
|
||||
# form known outputs and collect residual tracers
|
||||
out_known_tracers = [
|
||||
pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
|
||||
for uk in out_unknowns if not uk]
|
||||
residuals = list(out_consts_)
|
||||
out_knowns, residuals = split_list(out_consts, [len(out_consts)-num_res])
|
||||
|
||||
# set up unknown outputs with a recipe to call remat
|
||||
res_tracers = map(trace.new_instantiated_const, residuals)
|
||||
in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
|
||||
_, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
|
||||
_, tracers_staged = partition_list(in_used_staged, tracers)
|
||||
in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)
|
||||
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
|
||||
for x in jaxpr_unknown.outvars]
|
||||
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
|
||||
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
|
||||
new_params, jaxpr_unknown.effects, source_info_util.current())
|
||||
new_params, jaxpr_unknown.effects,
|
||||
source_info_util.current())
|
||||
for t in out_jaxpr_tracers: t.recipe = recipe
|
||||
|
||||
# zip together known and unknown outputs
|
||||
return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
|
||||
return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
|
||||
pe.custom_partial_eval_rules[remat_p] = remat_partial_eval
|
||||
|
||||
def remat_partial_eval_custom_params_updater(_, __, ___, ____, params_known, params_staged):
|
||||
def remat_partial_eval_custom_params_updater(*args):
|
||||
*_, params_known, params_staged = args
|
||||
return params_known, dict(params_staged, differentiated=True)
|
||||
pe.partial_eval_jaxpr_custom_rules[remat_p] = \
|
||||
partial(pe.call_partial_eval_custom_rule, 'jaxpr',
|
||||
|
@ -1057,12 +1057,12 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore
|
||||
|
||||
def _xmap_partial_eval_custom_params_updater(
|
||||
unks_in: Sequence[bool],
|
||||
unks_in: Sequence[bool], inst_in: Sequence[bool],
|
||||
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
|
||||
num_res: int, params_known: dict, params_staged: dict
|
||||
) -> Tuple[dict, dict]:
|
||||
assert params_known['spmd_in_axes'] is None and params_known['spmd_out_axes'] is None
|
||||
assert params_staged['spmd_in_axes'] is None and params_staged['spmd_out_axes'] is None
|
||||
assert params_known['spmd_in_axes'] is None is params_known['spmd_out_axes']
|
||||
assert params_staged['spmd_in_axes'] is None is params_staged['spmd_out_axes']
|
||||
|
||||
# pruned inputs to jaxpr_known according to unks_in
|
||||
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
|
||||
@ -1085,13 +1085,15 @@ def _xmap_partial_eval_custom_params_updater(
|
||||
assert len(new_params_known['in_axes']) == len(params_known['call_jaxpr'].invars)
|
||||
assert len(new_params_known['out_axes']) == len(params_known['call_jaxpr'].outvars)
|
||||
|
||||
# added num_res new inputs to jaxpr_staged
|
||||
donated_invars_staged = (*(False for _ in range(num_res)), *params_staged['donated_invars'])
|
||||
# added num_res new inputs to jaxpr_staged, and pruning according to inst_in
|
||||
_, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
|
||||
donated_invars_staged = [False] * num_res + donated_invars_staged
|
||||
_, in_axes_staged = pe.partition_list(inst_in, params_staged['in_axes'])
|
||||
in_axes_staged = [*residual_axes, *in_axes_staged]
|
||||
_, out_axes_staged = pe.partition_list(kept_outs_staged, params_staged['out_axes'])
|
||||
new_params_staged = dict(params_staged,
|
||||
in_axes=(*residual_axes, *params_staged['in_axes']),
|
||||
new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
|
||||
out_axes=tuple(out_axes_staged),
|
||||
donated_invars=donated_invars_staged)
|
||||
donated_invars=tuple(donated_invars_staged))
|
||||
assert len(new_params_staged['in_axes']) == len(params_staged['call_jaxpr'].invars)
|
||||
assert len(new_params_staged['out_axes']) == len(params_staged['call_jaxpr'].outvars)
|
||||
return new_params_known, new_params_staged
|
||||
|
@ -938,8 +938,9 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
|
||||
if params['policy']:
|
||||
# unzip into jaxpr_known and jaxpr_unknown
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = _partial_eval_jaxpr_custom(
|
||||
jaxpr, in_unknowns, params['policy'])
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
|
||||
partial_eval_jaxpr_custom(jaxpr, in_unknowns, [True] * len(in_unknowns),
|
||||
False, False, params['policy'])
|
||||
jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
|
||||
_, used_outs_unknown = partition_list(out_inst, out_unknowns)
|
||||
jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown)
|
||||
@ -1008,20 +1009,31 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
return merge_lists(out_unknowns, known_outputs, unknown_outputs)
|
||||
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
||||
|
||||
def _partition_knowns(pvals, unknowns: Sequence[bool]):
|
||||
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
|
||||
[e for e, unknown in zip(pvals, unknowns) if unknown])
|
||||
def partial_eval_jaxpr_custom(
|
||||
jaxpr: Jaxpr,
|
||||
in_unknowns: Sequence[bool],
|
||||
in_inst: Sequence[bool],
|
||||
ensure_out_unknowns: Union[bool, Sequence[bool]],
|
||||
ensure_out_inst: Union[bool, Sequence[bool]],
|
||||
saveable: Callable[..., bool],
|
||||
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
|
||||
if type(ensure_out_unknowns) is bool:
|
||||
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
|
||||
if type(ensure_out_inst) is bool:
|
||||
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
|
||||
return _partial_eval_jaxpr_custom_cached(
|
||||
jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
|
||||
tuple(ensure_out_inst), saveable)
|
||||
|
||||
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
|
||||
assert len(known_list) + len(unknown_list) == len(which_unknown)
|
||||
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
|
||||
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
|
||||
|
||||
|
||||
def _partial_eval_jaxpr_custom(
|
||||
jaxpr: Jaxpr, in_unknowns: Sequence[bool], saveable: Callable[..., bool],
|
||||
) -> Tuple[Jaxpr, Jaxpr, Sequence[bool], Sequence[bool], int]:
|
||||
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
|
||||
@weakref_lru_cache
|
||||
def _partial_eval_jaxpr_custom_cached(
|
||||
jaxpr: Jaxpr,
|
||||
in_unknowns: Tuple[bool, ...],
|
||||
in_inst: Tuple[bool, ...],
|
||||
ensure_out_unknowns: Tuple[bool, ...],
|
||||
ensure_out_inst: Tuple[bool, ...],
|
||||
saveable: Callable[..., bool],
|
||||
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
|
||||
env: Dict[Var, Tuple[bool, bool]] = {}
|
||||
residuals: OrderedSet[Var] = OrderedSet()
|
||||
|
||||
@ -1040,7 +1052,7 @@ def _partial_eval_jaxpr_custom(
|
||||
return x
|
||||
|
||||
known_eqns, staged_eqns = [], []
|
||||
map(write, in_unknowns, [True] * len(in_unknowns), jaxpr.invars)
|
||||
map(write, in_unknowns, in_inst, jaxpr.invars)
|
||||
for eqn in jaxpr.eqns:
|
||||
unks_in, inst_in = unzip2(map(read, eqn.invars))
|
||||
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
|
||||
@ -1064,17 +1076,23 @@ def _partial_eval_jaxpr_custom(
|
||||
out_unknowns, out_inst = unzip2(map(read, jaxpr.outvars))
|
||||
assert all(type(v) is Var for v in residuals), residuals
|
||||
|
||||
for x, inst, ensure_inst in zip(jaxpr.outvars, out_inst, ensure_out_inst):
|
||||
if ensure_inst: ensure_instantiated(inst, x)
|
||||
out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns)
|
||||
out_inst = map(op.or_, out_inst, ensure_out_inst)
|
||||
|
||||
ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
|
||||
outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
|
||||
known_effects = core.join_effects(*(eqn.effects for eqn in known_eqns))
|
||||
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns,
|
||||
known_effects)
|
||||
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known, [*outs_known, *residuals],
|
||||
known_eqns, known_effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
|
||||
|
||||
_, ins_staged = partition_list(in_inst, jaxpr.invars)
|
||||
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
|
||||
staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns))
|
||||
jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged,
|
||||
staged_eqns, staged_effects)
|
||||
jaxpr_staged = Jaxpr(jaxpr.constvars, [*residuals, *ins_staged],
|
||||
outs_staged, staged_eqns, staged_effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
|
||||
|
||||
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
|
||||
@ -1105,7 +1123,8 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
|
||||
|
||||
|
||||
ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
|
||||
int, dict, dict], Tuple[dict, dict]]
|
||||
Sequence[bool], int, dict, dict],
|
||||
Tuple[dict, dict]]
|
||||
|
||||
def call_partial_eval_custom_rule(
|
||||
jaxpr_param_name: str, params_updater: ParamsUpdater,
|
||||
@ -1114,21 +1133,21 @@ def call_partial_eval_custom_rule(
|
||||
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
|
||||
jaxpr = eqn.params[jaxpr_param_name]
|
||||
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
|
||||
_partial_eval_jaxpr_custom(jaxpr, unks_in, saveable)
|
||||
partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
|
||||
ins_known, _ = partition_list(unks_in, eqn.invars)
|
||||
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
||||
_, ins_staged = partition_list(inst_in, eqn.invars)
|
||||
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
||||
kept_outs_staged = inst_out
|
||||
newvar = core.gensym([jaxpr_known, jaxpr_staged])
|
||||
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
|
||||
params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
|
||||
params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
|
||||
params_known, params_staged = params_updater(
|
||||
unks_in, map(op.not_, unks_out), kept_outs_staged, num_res, params_known,
|
||||
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
|
||||
params_staged)
|
||||
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
|
||||
eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
|
||||
eqn_staged = new_jaxpr_eqn([*residuals, *eqn.invars], out_binders_staged,
|
||||
eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
|
||||
eqn.primitive, params_staged,
|
||||
jaxpr_staged.effects, eqn.source_info)
|
||||
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
|
||||
@ -1137,13 +1156,14 @@ def call_partial_eval_custom_rule(
|
||||
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
|
||||
partial_eval_jaxpr_custom_rules[core.call_p] = \
|
||||
partial(call_partial_eval_custom_rule, 'call_jaxpr',
|
||||
lambda _, __, ___, ____, x, y: (x, y))
|
||||
lambda _, __, ___, ____, _____, x, y: (x, y))
|
||||
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
|
||||
partial(call_partial_eval_custom_rule, 'call_jaxpr',
|
||||
lambda _, __, ___, ____, x, y: (x, y))
|
||||
lambda _, __, ___, ____, _____, x, y: (x, y))
|
||||
partial_eval_jaxpr_custom_rules[remat_call_p] = \
|
||||
partial(call_partial_eval_custom_rule, 'call_jaxpr',
|
||||
lambda _, __, ___, ____, p1, p2: (p1, dict(p2, differentiated=True)))
|
||||
lambda _, __, ___, ____, _____, p1, p2:
|
||||
(p1, dict(p2, differentiated=True)))
|
||||
|
||||
|
||||
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
|
||||
|
@ -433,16 +433,17 @@ ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
|
||||
|
||||
|
||||
def _xla_call_partial_eval_custom_params_updater(
|
||||
unks_in: Sequence[bool],
|
||||
unks_in: Sequence[bool], inst_in: Sequence[bool],
|
||||
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
|
||||
num_res: int, params_known: dict, params_staged: dict
|
||||
) -> Tuple[dict, dict]:
|
||||
# pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
|
||||
donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
|
||||
new_params_known = dict(params_known, donated_invars=tuple(donated_invars_known))
|
||||
donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
|
||||
new_params_known = dict(params_known, donated_invars=tuple(donated_known))
|
||||
# added num_res new inputs to jaxpr_staged, so extend donated_invars
|
||||
donated_invars_staged = [*([False] * num_res), *params_staged['donated_invars']]
|
||||
new_params_staged = dict(params_staged, donated_invars=tuple(donated_invars_staged))
|
||||
_, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
|
||||
donated_staged = [False] * num_res + donated_staged_
|
||||
new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
|
||||
return new_params_known, new_params_staged
|
||||
pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
|
||||
partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
|
||||
|
@ -4007,17 +4007,20 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertNotIn('conditional', c.as_hlo_text())
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{policy_name}", "policy": policy,
|
||||
"in_jaxpr2": in_jaxpr2, "not_in_jaxpr2": not_in_jaxpr2}
|
||||
{"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat,
|
||||
"policy": policy, "in_jaxpr2": in_jaxpr2, "not_in_jaxpr2": not_in_jaxpr2}
|
||||
for remat_name, remat in [
|
||||
('old_remat', api.remat),
|
||||
('new_remat', new_checkpoint),
|
||||
]
|
||||
for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [
|
||||
('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']),
|
||||
('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []),
|
||||
('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']),
|
||||
])
|
||||
def test_remat_custom_policy(self, policy, in_jaxpr2, not_in_jaxpr2):
|
||||
def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2):
|
||||
for square in [lambda x: x * x, api.jit(lambda x: x * x)]:
|
||||
f = api.remat(lambda x: jnp.sin(square(jnp.sin(x))),
|
||||
policy=policy)
|
||||
f = remat(lambda x: jnp.sin(square(jnp.sin(x))), policy=policy)
|
||||
y, f_lin = api.linearize(f, 1.)
|
||||
ydot = f_lin(2.)
|
||||
jaxpr_text = str(f_lin.func.args[0])
|
||||
@ -4031,18 +4034,30 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ydot, ydot_expected)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
def test_remat_custom_policy_save_cos(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{remat_name}", "remat": remat}
|
||||
for remat_name, remat in [
|
||||
('old_remat', api.remat),
|
||||
('new_remat', new_checkpoint),
|
||||
])
|
||||
def test_remat_custom_policy_save_cos(self, remat):
|
||||
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
|
||||
f = api.remat(lambda x: jnp.sin(jnp.sin(x)), # different function
|
||||
policy=save_cos)
|
||||
f = remat(lambda x: jnp.sin(jnp.sin(x)), # different function
|
||||
policy=save_cos)
|
||||
_, f_lin = api.linearize(f, 1.)
|
||||
jaxpr_text = str(f_lin.func.args[0])
|
||||
self.assertNotIn(' sin ', jaxpr_text)
|
||||
self.assertNotIn(' cos ', jaxpr_text)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
def test_remat_checkpoint_dots(self):
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{remat_name}", "remat": remat}
|
||||
for remat_name, remat in [
|
||||
('old_remat', api.remat),
|
||||
('new_remat', new_checkpoint),
|
||||
])
|
||||
def test_remat_checkpoint_dots(self, remat):
|
||||
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
@ -4058,8 +4073,14 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertEqual(jaxpr_text.count(' dot_'), 6)
|
||||
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
def test_remat_checkpoint_dots_with_no_batch_dims(self):
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{remat_name}", "remat": remat}
|
||||
for remat_name, remat in [
|
||||
('old_remat', api.remat),
|
||||
('new_remat', new_checkpoint),
|
||||
])
|
||||
def test_remat_checkpoint_dots_with_no_batch_dims(self, remat):
|
||||
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
|
||||
def f(x):
|
||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
@ -4075,8 +4096,14 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertEqual(jaxpr_text.count(' dot_general'), 6)
|
||||
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
def test_remat_checkpoint_dots_with_no_batch_dims2(self):
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{remat_name}", "remat": remat}
|
||||
for remat_name, remat in [
|
||||
('old_remat', api.remat),
|
||||
('new_remat', new_checkpoint),
|
||||
])
|
||||
def test_remat_checkpoint_dots_with_no_batch_dims2(self, remat):
|
||||
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
|
||||
def f(x):
|
||||
x = jnp.einsum('nij,njk->nik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
@ -4092,9 +4119,15 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertEqual(jaxpr_text.count(' dot_general'), 9)
|
||||
jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
def test_remat_checkpoint_dots_jit(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{remat_name}", "remat": remat}
|
||||
for remat_name, remat in [
|
||||
('old_remat', api.remat),
|
||||
('new_remat', new_checkpoint),
|
||||
])
|
||||
def test_remat_checkpoint_dots_jit(self, remat):
|
||||
@api.jit
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x * 1e-3)
|
||||
@ -4196,11 +4229,17 @@ class RematTest(jtu.JaxTestCase):
|
||||
return lax.scan(lambda x, _: (f(x), None), x, None, length=2)[0]
|
||||
jtu.check_grads(g, (3.,), order=2, modes=['rev'])
|
||||
|
||||
def test_remat_dropvar_policy(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{remat_name}", "remat": remat}
|
||||
for remat_name, remat in [
|
||||
('old_remat', api.remat),
|
||||
('new_remat', new_checkpoint),
|
||||
])
|
||||
def test_remat_dropvar_policy(self, remat):
|
||||
def f(x):
|
||||
return x, x
|
||||
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def g(x):
|
||||
x = api.grad(lambda x: f(x)[0])(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user