Merge pull request #13859 from mattjj:remat-named-policy-tweak

PiperOrigin-RevId: 499899379
This commit is contained in:
jax authors 2023-01-05 09:07:50 -08:00
commit 5137277463

View File

@ -71,8 +71,17 @@ def dot_with_no_batch_dims(prim, *_, **params) -> bool:
name_p = core.Primitive('name')
def save_anything_except_these_names(*names_not_to_save):
"""Save any values (not just named ones) excluding the names given."""
names_not_to_save = frozenset(names_not_to_save)
def policy(prim, *_, **params):
if prim is name_p:
return params['name'] not in names_not_to_save
return True # allow saving anything which is not named
return policy
def save_any_names_but_these(*names_not_to_save):
# Save named values, excluding the names given.
"""Save only named values, excluding the names given."""
names_not_to_save = frozenset(names_not_to_save)
def policy(prim, *_, **params):
if prim is name_p:
@ -81,7 +90,7 @@ def save_any_names_but_these(*names_not_to_save):
return policy
def save_only_these_names(*names_which_can_be_saved):
# Save named values, only among the names given.
"""Save only named values, and only among the names given."""
names_which_can_be_saved = set(names_which_can_be_saved)
def policy(prim, *_, **params):
if prim is name_p:
@ -103,6 +112,7 @@ checkpoint_policies = types.SimpleNamespace(
nothing_saveable=nothing_saveable,
checkpoint_dots=checkpoint_dots,
checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims,
save_anything_except_these_names=save_anything_except_these_names,
save_any_names_but_these=save_any_names_but_these,
save_only_these_names=save_only_these_names,
save_from_both_policies=save_from_both_policies)