From d6a93522705bd345811d1470b1a0d46d9117f6f5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 20 Nov 2023 11:43:41 -0800 Subject: [PATCH] Only call executale.get_output_memory_kinds() if jax_enable_memories is True PiperOrigin-RevId: 584087022 --- jax/_src/interpreters/pxla.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 91360a8f8..58f6abcc2 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2276,15 +2276,18 @@ def get_gspmd_shardings_from_executable( ) -> Sequence[sharding_impls.XLACompatibleSharding] | None: from jax._src import pjit - if all_default_mem_kind: - omk = [None] * num_out_avals - else: - try: - omk = xla_executable.get_output_memory_kinds()[0] - if num_ordered_effects > 0: - omk = omk[num_ordered_effects:] - except: + if config.enable_memories.value: + if all_default_mem_kind: omk = [None] * num_out_avals + else: + try: + omk = xla_executable.get_output_memory_kinds()[0] + if num_ordered_effects > 0: + omk = omk[num_ordered_effects:] + except: + omk = [None] * num_out_avals + else: + omk = [None] * num_out_avals assert len(omk) == num_out_avals, (len(omk), num_out_avals)