mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Only call executale.get_output_memory_kinds() if jax_enable_memories is True
PiperOrigin-RevId: 584087022
This commit is contained in:
parent
ab9c973031
commit
d6a9352270
@ -2276,15 +2276,18 @@ def get_gspmd_shardings_from_executable(
|
||||
) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
|
||||
from jax._src import pjit
|
||||
|
||||
if all_default_mem_kind:
|
||||
omk = [None] * num_out_avals
|
||||
else:
|
||||
try:
|
||||
omk = xla_executable.get_output_memory_kinds()[0]
|
||||
if num_ordered_effects > 0:
|
||||
omk = omk[num_ordered_effects:]
|
||||
except:
|
||||
if config.enable_memories.value:
|
||||
if all_default_mem_kind:
|
||||
omk = [None] * num_out_avals
|
||||
else:
|
||||
try:
|
||||
omk = xla_executable.get_output_memory_kinds()[0]
|
||||
if num_ordered_effects > 0:
|
||||
omk = omk[num_ordered_effects:]
|
||||
except:
|
||||
omk = [None] * num_out_avals
|
||||
else:
|
||||
omk = [None] * num_out_avals
|
||||
|
||||
assert len(omk) == num_out_avals, (len(omk), num_out_avals)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user