Only call executale.get_output_memory_kinds() if jax_enable_memories is True

PiperOrigin-RevId: 584087022
This commit is contained in:
Yash Katariya 2023-11-20 11:43:41 -08:00 committed by jax authors
parent ab9c973031
commit d6a9352270

View File

@ -2276,15 +2276,18 @@ def get_gspmd_shardings_from_executable(
) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
from jax._src import pjit
if all_default_mem_kind:
omk = [None] * num_out_avals
else:
try:
omk = xla_executable.get_output_memory_kinds()[0]
if num_ordered_effects > 0:
omk = omk[num_ordered_effects:]
except:
if config.enable_memories.value:
if all_default_mem_kind:
omk = [None] * num_out_avals
else:
try:
omk = xla_executable.get_output_memory_kinds()[0]
if num_ordered_effects > 0:
omk = omk[num_ordered_effects:]
except:
omk = [None] * num_out_avals
else:
omk = [None] * num_out_avals
assert len(omk) == num_out_avals, (len(omk), num_out_avals)