tweak remat named policies

This commit is contained in:
Matthew Johnson 2023-01-04 08:25:57 -08:00
parent a1c699f59d
commit 4f7cf622d4

View File

@ -71,8 +71,17 @@ def dot_with_no_batch_dims(prim, *_, **params) -> bool:
name_p = core.Primitive('name')
def save_anything_except_these_names(*names_not_to_save):
"""Save any values (not just named ones) excluding the names given."""
names_not_to_save = frozenset(names_not_to_save)
def policy(prim, *_, **params):
if prim is name_p:
return params['name'] not in names_not_to_save
return True # allow saving anything which is not named
return policy
def save_any_names_but_these(*names_not_to_save):
# Save named values, excluding the names given.
"""Save only named values, excluding the names given."""
names_not_to_save = frozenset(names_not_to_save)
def policy(prim, *_, **params):
if prim is name_p:
@ -81,7 +90,7 @@ def save_any_names_but_these(*names_not_to_save):
return policy
def save_only_these_names(*names_which_can_be_saved):
# Save named values, only among the names given.
"""Save only named values, and only among the names given."""
names_which_can_be_saved = set(names_which_can_be_saved)
def policy(prim, *_, **params):
if prim is name_p:
@ -103,6 +112,7 @@ checkpoint_policies = types.SimpleNamespace(
nothing_saveable=nothing_saveable,
checkpoint_dots=checkpoint_dots,
checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims,
save_anything_except_these_names=save_anything_except_these_names,
save_any_names_but_these=save_any_names_but_these,
save_only_these_names=save_only_these_names,
save_from_both_policies=save_from_both_policies)