diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 91360a8f8..58f6abcc2 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2276,15 +2276,18 @@ def get_gspmd_shardings_from_executable( ) -> Sequence[sharding_impls.XLACompatibleSharding] | None: from jax._src import pjit - if all_default_mem_kind: - omk = [None] * num_out_avals - else: - try: - omk = xla_executable.get_output_memory_kinds()[0] - if num_ordered_effects > 0: - omk = omk[num_ordered_effects:] - except: + if config.enable_memories.value: + if all_default_mem_kind: omk = [None] * num_out_avals + else: + try: + omk = xla_executable.get_output_memory_kinds()[0] + if num_ordered_effects > 0: + omk = omk[num_ordered_effects:] + except: + omk = [None] * num_out_avals + else: + omk = [None] * num_out_avals assert len(omk) == num_out_avals, (len(omk), num_out_avals)