Make looking up shardings from the executable consistent. If out_shardings are specified on jit, always check them against the output shardings returned by get_output_shardings on the executable.
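For illustration, a minimal sketch of the user-facing contract this enforces. The mesh axis name, device layout, and array shape below are assumptions for the example, not part of this commit:

    # Hedged sketch: assumes a backend whose device count divides the array size.
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), axis_names=("x",))
    out_sharding = NamedSharding(mesh, P("x"))

    # With out_shardings specified on jit, the sharding recovered from the
    # compiled executable must agree with it; a mismatch now raises an
    # AssertionError rather than being silently accepted.
    f = jax.jit(lambda a: a * 2, out_shardings=out_sharding)
    y = f(jnp.arange(8.0))
    print(y.sharding)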

PiperOrigin-RevId: 583456869
Yash Katariya 2023-11-17 12:18:46 -08:00 committed by jax authors
parent 8e8dc263bc
commit 38729552fa
2 changed files with 55 additions and 34 deletions

View File

@@ -2273,7 +2273,7 @@ def get_gspmd_shardings_from_executable(
num_out_avals: int,
num_ordered_effects: int,
all_default_mem_kind: bool,
) -> Sequence[sharding_impls.XLACompatibleSharding]:
) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
from jax._src import pjit
if all_default_mem_kind:
@@ -2297,6 +2297,11 @@ def get_gspmd_shardings_from_executable(
for mk in omk]
_, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
if not out_op_shardings:
return None
if num_ordered_effects > 0:
out_op_shardings = out_op_shardings[num_ordered_effects:]
# This condition happens when all the elements in the output tuple have the
# same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
@@ -2516,6 +2521,39 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
return xla_executable, compile_options
def _get_shardings_from_executable(
xla_executable, out_shardings, device_assignment, global_out_avals,
num_ordered_effects, all_default_mem_kind
):
out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment, len(global_out_avals),
num_ordered_effects, all_default_mem_kind) # type: ignore
if out_shardings_xla is None:
return out_shardings, (False,) * len(global_out_avals)
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
global_out_avals):
if is_unspecified(orig):
out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
# MANUAL HloSharding comes from other partitioning frameworks.
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
not xla_hlo_s.is_manual() and
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
xla_s.memory_kind != orig.memory_kind)): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
out_shardings.append(orig)
are_out_shardings_from_xla.append(False)
return out_shardings, are_out_shardings_from_xla
@dataclasses.dataclass
class UnloadedMeshExecutable:
xla_executable: Any
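Conceptually, the new _get_shardings_from_executable helper resolves each output sharding as sketched below. This standalone approximation uses None in place of an unspecified sharding and the public Sharding.is_equivalent_to check instead of the HloSharding comparison in the diff:

    def resolve_out_shardings(xla_out_shardings, user_out_shardings, out_avals):
      # Sketch only: for each output, prefer the user's sharding when given and
      # assert that the executable's sharding agrees with it; otherwise take the
      # sharding the executable chose and remember that it came from XLA.
      resolved, from_xla = [], []
      for xla_s, user_s, aval in zip(xla_out_shardings, user_out_shardings,
                                     out_avals):
        if user_s is None:  # stands in for is_unspecified(orig) in this sketch
          resolved.append(xla_s)
          from_xla.append(True)
        elif not xla_s.is_equivalent_to(user_s, aval.ndim):
          raise AssertionError(
              f"Unexpected XLA sharding override: (XLA) {xla_s} != {user_s} "
              "(User sharding)")
        else:
          resolved.append(user_s)
          from_xla.append(False)
      return resolved, from_xla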
@@ -2636,34 +2674,18 @@ class UnloadedMeshExecutable:
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
and pmap_nreps == 1):
assert mesh is None
# TODO(yashkatariya): Make da directly usable in the downstream code
# without tuple conversion.
device_assignment = tuple(da)
out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment, len(global_out_avals),
len(ordered_effects), all_default_mem_kind) # type: ignore
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
global_out_avals):
if is_unspecified(orig):
out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
xla_s.memory_kind != orig.memory_kind): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
out_shardings.append(orig)
are_out_shardings_from_xla.append(False)
else:
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if pmap_nreps == 1:
assert mesh is None
# TODO(yashkatariya): Make da directly usable in the downstream code
# without tuple conversion.
out_shardings, are_out_shardings_from_xla = _get_shardings_from_executable(
xla_executable, out_shardings, tuple(da), global_out_avals,
len(ordered_effects), all_default_mem_kind)
else:
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if xla_extension_version >= 215:
in_layouts, out_layouts = _get_layouts_from_executable(
@@ -2672,10 +2694,6 @@ class UnloadedMeshExecutable:
assert all(i is None for i in in_layouts)
assert all(o is None for o in out_layouts)
if pmap_nreps > 1:
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla,
global_in_avals, global_out_avals)
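For context, the shardings an executable reports can also be inspected from a compiled function via JAX's ahead-of-time API; a hedged sketch (the output_shardings attribute on the Compiled object is an assumption about the API available in this JAX version):

    # Sketch: inspect the output shardings the compiled executable reports.
    import jax
    import jax.numpy as jnp

    compiled = jax.jit(lambda a: a + 1).lower(jnp.arange(4.0)).compile()
    print(compiled.output_shardings)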

View File

@@ -2049,7 +2049,8 @@ def _fast_path_get_device_assignment(
return da
def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
def _get_partition_spec(
ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
return [get_single_pspec(p) for p in ppspec]
@@ -2068,7 +2069,9 @@ def get_op_sharding_from_executable(
return in_op_shardings, out_op_shardings
def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
def _get_ppspec_from_executable(
executable, mesh
) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
input_op_shardings, output_op_sharding = get_op_sharding_from_executable(
executable
)