From 38729552fac8bc71d4dd4d79b9843e44d99b17b6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 17 Nov 2023 12:18:46 -0800 Subject: [PATCH] Make looking up shardings from executable consistent. If `out_shardings` are specified on `jit`, always check it against the `get_output_shardings` from the executable. PiperOrigin-RevId: 583456869 --- jax/_src/interpreters/pxla.py | 82 +++++++++++++++++++++-------------- jax/_src/pjit.py | 7 ++- 2 files changed, 55 insertions(+), 34 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 234bf5540..18195a813 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2273,7 +2273,7 @@ def get_gspmd_shardings_from_executable( num_out_avals: int, num_ordered_effects: int, all_default_mem_kind: bool, -) -> Sequence[sharding_impls.XLACompatibleSharding]: +) -> Sequence[sharding_impls.XLACompatibleSharding] | None: from jax._src import pjit if all_default_mem_kind: @@ -2297,6 +2297,11 @@ def get_gspmd_shardings_from_executable( for mk in omk] _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) + if not out_op_shardings: + return None + + if num_ordered_effects > 0: + out_op_shardings = out_op_shardings[num_ordered_effects:] # This condition happens when all the elements in the output tuple have the # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to @@ -2516,6 +2521,39 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, return xla_executable, compile_options +def _get_shardings_from_executable( + xla_executable, out_shardings, device_assignment, global_out_avals, + num_ordered_effects, all_default_mem_kind + ): + out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore + xla_executable, device_assignment, len(global_out_avals), + num_ordered_effects, all_default_mem_kind) # type: ignore + if out_shardings_xla is None: + return out_shardings, (False,) * len(global_out_avals) + + orig_out_shardings = out_shardings + out_shardings, are_out_shardings_from_xla = [], [] # type: ignore + for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings, + global_out_avals): + if is_unspecified(orig): + out_shardings.append(xla_s) + are_out_shardings_from_xla.append(True) + else: + xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore + orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore + # MANUAL HloSharding comes from other partitioning frameworks. + if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and + not xla_hlo_s.is_manual() and + (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or + xla_s.memory_kind != orig.memory_kind)): # type: ignore + raise AssertionError( + f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " + "(User sharding)") + out_shardings.append(orig) + are_out_shardings_from_xla.append(False) + return out_shardings, are_out_shardings_from_xla + + @dataclasses.dataclass class UnloadedMeshExecutable: xla_executable: Any @@ -2636,34 +2674,18 @@ class UnloadedMeshExecutable: for x, o in safe_zip(out_shardings_xla, out_shardings) ] out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple) - elif (out_shardings and any(is_unspecified(o) for o in out_shardings) - and pmap_nreps == 1): - assert mesh is None - # TODO(yashkatariya): Make da directly usable in the downstream code - # without tuple conversion. - device_assignment = tuple(da) - out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore - xla_executable, device_assignment, len(global_out_avals), - len(ordered_effects), all_default_mem_kind) # type: ignore - orig_out_shardings = out_shardings - out_shardings, are_out_shardings_from_xla = [], [] # type: ignore - for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings, - global_out_avals): - if is_unspecified(orig): - out_shardings.append(xla_s) - are_out_shardings_from_xla.append(True) - else: - xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore - orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore - if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or - xla_s.memory_kind != orig.memory_kind): # type: ignore - raise AssertionError( - f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " - "(User sharding)") - out_shardings.append(orig) - are_out_shardings_from_xla.append(False) else: - are_out_shardings_from_xla = (False,) * len(global_out_avals) + if pmap_nreps == 1: + assert mesh is None + # TODO(yashkatariya): Make da directly usable in the downstream code + # without tuple conversion. + out_shardings, are_out_shardings_from_xla = _get_shardings_from_executable( + xla_executable, out_shardings, tuple(da), global_out_avals, + len(ordered_effects), all_default_mem_kind) + else: + in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( + xla_executable.local_devices(), len(in_shardings), len(out_shardings)) + are_out_shardings_from_xla = (False,) * len(global_out_avals) if xla_extension_version >= 215: in_layouts, out_layouts = _get_layouts_from_executable( @@ -2672,10 +2694,6 @@ class UnloadedMeshExecutable: assert all(i is None for i in in_layouts) assert all(o is None for o in out_layouts) - if pmap_nreps > 1: - in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( - xla_executable.local_devices(), len(in_shardings), len(out_shardings)) - out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding( in_shardings, out_shardings, are_out_shardings_from_xla, global_in_avals, global_out_avals) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ca39ca3e1..c61dd5c93 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2049,7 +2049,8 @@ def _fast_path_get_device_assignment( return da -def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]: +def _get_partition_spec( + ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]: return [get_single_pspec(p) for p in ppspec] @@ -2068,7 +2069,9 @@ def get_op_sharding_from_executable( return in_op_shardings, out_op_shardings -def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]: +def _get_ppspec_from_executable( + executable, mesh + ) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]: input_op_shardings, output_op_sharding = get_op_sharding_from_executable( executable )