Make looking up shardings from executable consistent. If out_shardings are
specified on jit, always check them against the get_output_shardings from
the executable.

PiperOrigin-RevId: 583456869

parent 8e8dc263bc
commit 38729552fa
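For context, a minimal sketch (not part of this commit) of the user-facing
setup the change guards: an out_shardings passed explicitly to jit. The mesh
layout, axis name, and function body below are illustrative assumptions.

    # Hedged sketch: out_shardings specified on jit, the case this commit
    # now always validates against the compiled executable.
    from functools import partial

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), axis_names=("x",))
    sharding = NamedSharding(mesh, P("x"))

    @partial(jax.jit, out_shardings=sharding)
    def double(a):
      return a * 2

    # With this change, the requested sharding is always compared against
    # the output shardings the compiled executable reports; a mismatch
    # raises AssertionError instead of being silently skipped.
    y = double(jnp.arange(8.0))
    print(y.sharding)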
@@ -2273,7 +2273,7 @@ def get_gspmd_shardings_from_executable(
     num_out_avals: int,
     num_ordered_effects: int,
     all_default_mem_kind: bool,
-) -> Sequence[sharding_impls.XLACompatibleSharding]:
+) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
   from jax._src import pjit
 
   if all_default_mem_kind:
@@ -2297,6 +2297,11 @@ def get_gspmd_shardings_from_executable(
                      for mk in omk]
 
   _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
+  if not out_op_shardings:
+    return None
+
+  if num_ordered_effects > 0:
+    out_op_shardings = out_op_shardings[num_ordered_effects:]
 
   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
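The Optional return is the new contract here: an empty op-sharding report from
the executable now signals "unknown" rather than producing a sequence the
caller would wrongly trust. A hedged, self-contained sketch of that contract,
with stand-in names rather than jax internals:

    from typing import Optional, Sequence

    def lookup(out_op_shardings: Sequence[str]) -> Optional[Sequence[str]]:
      # Mirrors the early exit added above: an empty report means "unknown".
      if not out_op_shardings:
        return None
      return out_op_shardings

    assert lookup([]) is None        # caller falls back to user shardings
    assert lookup(["s"]) == ["s"]    # caller may compare against or adopt these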
@@ -2516,6 +2521,39 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
   return xla_executable, compile_options
 
 
+def _get_shardings_from_executable(
+    xla_executable, out_shardings, device_assignment, global_out_avals,
+    num_ordered_effects, all_default_mem_kind
+):
+  out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
+      xla_executable, device_assignment, len(global_out_avals),
+      num_ordered_effects, all_default_mem_kind)  # type: ignore
+  if out_shardings_xla is None:
+    return out_shardings, (False,) * len(global_out_avals)
+
+  orig_out_shardings = out_shardings
+  out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
+  for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
+                                    global_out_avals):
+    if is_unspecified(orig):
+      out_shardings.append(xla_s)
+      are_out_shardings_from_xla.append(True)
+    else:
+      xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+      orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+      # MANUAL HloSharding comes from other partitioning frameworks.
+      if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
+          not xla_hlo_s.is_manual() and
+          (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
+           xla_s.memory_kind != orig.memory_kind)):  # type: ignore
+        raise AssertionError(
+            f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
+            "(User sharding)")
+      out_shardings.append(orig)
+      are_out_shardings_from_xla.append(False)
+  return out_shardings, are_out_shardings_from_xla
+
+
 @dataclasses.dataclass
 class UnloadedMeshExecutable:
   xla_executable: Any
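Distilled below, outside the diff, is the rule the new helper enforces: a
self-contained sketch with plain strings standing in for sharding objects and
None standing in for is_unspecified; the equality test stands in for
are_op_shardings_equal plus the memory_kind comparison.

    from typing import Optional, Sequence, Tuple

    def reconcile(
        xla_out: Optional[Sequence[str]],
        user_out: Sequence[Optional[str]],
    ) -> Tuple[Sequence[Optional[str]], Sequence[bool]]:
      # None from the executable means "no op shardings reported": keep the
      # user's shardings untouched (the early exit above).
      if xla_out is None:
        return user_out, [False] * len(user_out)
      out, from_xla = [], []
      for xla_s, user_s in zip(xla_out, user_out):
        if user_s is None:
          out.append(xla_s)          # adopt what XLA chose
          from_xla.append(True)
        elif xla_s != user_s:
          raise AssertionError(
              f"Unexpected XLA sharding override: (XLA) {xla_s} != {user_s} "
              "(User sharding)")
        else:
          out.append(user_s)
          from_xla.append(False)
      return out, from_xla

    # Matching shardings pass through; unspecified ones adopt XLA's choice.
    print(reconcile(["x", "y"], [None, "y"]))  # (['x', 'y'], [True, False])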
@@ -2636,34 +2674,18 @@ class UnloadedMeshExecutable:
           for x, o in safe_zip(out_shardings_xla, out_shardings)
       ]
       out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
-    elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
-          and pmap_nreps == 1):
-      assert mesh is None
-      # TODO(yashkatariya): Make da directly usable in the downstream code
-      # without tuple conversion.
-      device_assignment = tuple(da)
-      out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
-          xla_executable, device_assignment, len(global_out_avals),
-          len(ordered_effects), all_default_mem_kind)  # type: ignore
-      orig_out_shardings = out_shardings
-      out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
-      for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
-                                        global_out_avals):
-        if is_unspecified(orig):
-          out_shardings.append(xla_s)
-          are_out_shardings_from_xla.append(True)
-        else:
-          xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
-          orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
-          if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
-              xla_s.memory_kind != orig.memory_kind):  # type: ignore
-            raise AssertionError(
-                f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
-                "(User sharding)")
-          out_shardings.append(orig)
-          are_out_shardings_from_xla.append(False)
-    else:
-      are_out_shardings_from_xla = (False,) * len(global_out_avals)
+
+    if pmap_nreps == 1:
+      assert mesh is None
+      # TODO(yashkatariya): Make da directly usable in the downstream code
+      # without tuple conversion.
+      out_shardings, are_out_shardings_from_xla = _get_shardings_from_executable(
+          xla_executable, out_shardings, tuple(da), global_out_avals,
+          len(ordered_effects), all_default_mem_kind)
+    else:
+      in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
+          xla_executable.local_devices(), len(in_shardings), len(out_shardings))
+      are_out_shardings_from_xla = (False,) * len(global_out_avals)
 
     if xla_extension_version >= 215:
       in_layouts, out_layouts = _get_layouts_from_executable(
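The behavioral shift in this hunk is the gate on the single-replica path:
previously the executable's output shardings were consulted only when some
out_shardings were left unspecified; now they are fetched and validated
whenever pmap_nreps == 1. A toy sketch of the two gates, using stand-in
values rather than jax internals:

    # Hedged sketch contrasting old and new gating; UNSPECIFIED and both
    # helper names are stand-ins for illustration.
    UNSPECIFIED = object()

    def ran_check_before(out_shardings, pmap_nreps):
      # Old gate: only when at least one output sharding was unspecified.
      return (bool(out_shardings)
              and any(o is UNSPECIFIED for o in out_shardings)
              and pmap_nreps == 1)

    def runs_check_now(pmap_nreps):
      # New gate: always for single-replica executables, so fully
      # user-specified out_shardings are validated too.
      return pmap_nreps == 1

    assert not ran_check_before(["user_sharding"], 1)  # silently skipped before
    assert runs_check_now(1)                           # always checked now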
@@ -2672,10 +2694,6 @@ class UnloadedMeshExecutable:
       assert all(i is None for i in in_layouts)
       assert all(o is None for o in out_layouts)
 
-    if pmap_nreps > 1:
-      in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
-          xla_executable.local_devices(), len(in_shardings), len(out_shardings))
-
     out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
         in_shardings, out_shardings, are_out_shardings_from_xla,
         global_in_avals, global_out_avals)
@@ -2049,7 +2049,8 @@ def _fast_path_get_device_assignment(
   return da
 
 
-def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
+def _get_partition_spec(
+    ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
   return [get_single_pspec(p) for p in ppspec]
 
 
@@ -2068,7 +2069,9 @@ def get_op_sharding_from_executable(
   return in_op_shardings, out_op_shardings
 
 
-def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
+def _get_ppspec_from_executable(
+    executable, mesh
+) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
   input_op_shardings, output_op_sharding = get_op_sharding_from_executable(
       executable
   )