mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13859 from mattjj:remat-named-policy-tweak
PiperOrigin-RevId: 499899379
This commit is contained in:
commit
5137277463
@ -71,8 +71,17 @@ def dot_with_no_batch_dims(prim, *_, **params) -> bool:
|
||||
|
||||
name_p = core.Primitive('name')
|
||||
|
||||
def save_anything_except_these_names(*names_not_to_save):
|
||||
"""Save any values (not just named ones) excluding the names given."""
|
||||
names_not_to_save = frozenset(names_not_to_save)
|
||||
def policy(prim, *_, **params):
|
||||
if prim is name_p:
|
||||
return params['name'] not in names_not_to_save
|
||||
return True # allow saving anything which is not named
|
||||
return policy
|
||||
|
||||
def save_any_names_but_these(*names_not_to_save):
|
||||
# Save named values, excluding the names given.
|
||||
"""Save only named values, excluding the names given."""
|
||||
names_not_to_save = frozenset(names_not_to_save)
|
||||
def policy(prim, *_, **params):
|
||||
if prim is name_p:
|
||||
@ -81,7 +90,7 @@ def save_any_names_but_these(*names_not_to_save):
|
||||
return policy
|
||||
|
||||
def save_only_these_names(*names_which_can_be_saved):
|
||||
# Save named values, only among the names given.
|
||||
"""Save only named values, and only among the names given."""
|
||||
names_which_can_be_saved = set(names_which_can_be_saved)
|
||||
def policy(prim, *_, **params):
|
||||
if prim is name_p:
|
||||
@ -103,6 +112,7 @@ checkpoint_policies = types.SimpleNamespace(
|
||||
nothing_saveable=nothing_saveable,
|
||||
checkpoint_dots=checkpoint_dots,
|
||||
checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims,
|
||||
save_anything_except_these_names=save_anything_except_these_names,
|
||||
save_any_names_but_these=save_any_names_but_these,
|
||||
save_only_these_names=save_only_these_names,
|
||||
save_from_both_policies=save_from_both_policies)
|
||||
|
Loading…
x
Reference in New Issue
Block a user