mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix save_from_both_policies
in presence of save_and_offload_only_these_names
by comparing the enum
PiperOrigin-RevId: 706874882
This commit is contained in:
parent
772339ec60
commit
7dd401cb2a
@ -142,8 +142,14 @@ def save_and_offload_only_these_names(
|
||||
def save_from_both_policies(policy_1, policy_2):
|
||||
|
||||
def policy(prim, *args, **params):
|
||||
return policy_1(prim, *args, **params) or policy_2(prim, *args, **params)
|
||||
|
||||
out1 = policy_1(prim, *args, **params)
|
||||
out2 = policy_2(prim, *args, **params)
|
||||
if not (isinstance(out1, bool) and isinstance(out2, bool)):
|
||||
raise ValueError(
|
||||
"The return value of the policies should be a boolean. Got:"
|
||||
f" {out1} and {out2}. Please write a custom policy function directly,"
|
||||
" rather than using this helper function.")
|
||||
return out1 or out2
|
||||
return policy
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user