Fix save_from_both_policies in presence of save_and_offload_only_these_names by comparing the enum

PiperOrigin-RevId: 706874882
This commit is contained in:
Yash Katariya 2024-12-16 16:40:26 -08:00 committed by jax authors
parent 772339ec60
commit 7dd401cb2a

View File

@ -142,8 +142,14 @@ def save_and_offload_only_these_names(
def save_from_both_policies(policy_1, policy_2):
def policy(prim, *args, **params):
return policy_1(prim, *args, **params) or policy_2(prim, *args, **params)
out1 = policy_1(prim, *args, **params)
out2 = policy_2(prim, *args, **params)
if not (isinstance(out1, bool) and isinstance(out2, bool)):
raise ValueError(
"The return value of the policies should be a boolean. Got:"
f" {out1} and {out2}. Please write a custom policy function directly,"
" rather than using this helper function.")
return out1 or out2
return policy