Remove dead codepaths now that MemorySpaceDescription works in OSS

PiperOrigin-RevId: 715410774
This commit is contained in:
Yash Katariya 2025-01-14 09:21:57 -08:00 committed by jax authors
parent ee724565bf
commit b7e06f1937

View File

@ -2368,9 +2368,6 @@ def lower_sharding_computation(
out_layouts=out_layouts,
pmap_nreps=nreps,
shape_poly_state=shape_poly_state,
# TODO(yashkatariya): Remove `all_default_mem_kind` after
# MemoryDescription works in OSS.
all_default_mem_kind=all_default_mem_kind,
all_args_info=all_args_info,
pgle_profiler=pgle_profiler,
intermediate_shardings=unique_intermediate_shardings,
@ -2442,21 +2439,15 @@ def get_out_shardings_from_executable(
device_assignment: Sequence[xc.Device],
num_out_avals: int,
num_ordered_effects: int,
all_default_mem_kind: bool,
) -> Sequence[sharding_impls.GSPMDSharding] | None:
from jax._src import pjit
# TODO(yashkatariya): Remove `all_default_mem_kind` branch after
# MemoryDescription works in OSS.
if all_default_mem_kind:
try:
omk = xla_executable.get_output_memory_kinds()[0]
if num_ordered_effects > 0:
omk = omk[num_ordered_effects:]
except:
omk = [None] * num_out_avals
else:
try:
omk = xla_executable.get_output_memory_kinds()[0]
if num_ordered_effects > 0:
omk = omk[num_ordered_effects:]
except:
omk = [None] * num_out_avals
assert len(omk) == num_out_avals, (len(omk), num_out_avals)
@ -2781,11 +2772,11 @@ def _maybe_get_and_check_in_shardings(
def _maybe_get_and_check_out_shardings(
xla_executable, out_shardings, device_assignment, global_out_avals,
num_ordered_effects, all_default_mem_kind
num_ordered_effects
):
out_shardings_xla = get_out_shardings_from_executable(
xla_executable, device_assignment, len(global_out_avals),
num_ordered_effects, all_default_mem_kind)
num_ordered_effects)
if out_shardings_xla is None:
return out_shardings
@ -2893,7 +2884,6 @@ class UnloadedMeshExecutable:
pmap_nreps: int = 1,
mut: MutationData | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
all_default_mem_kind: bool = True,
all_args_info: AllArgsInfo | None = None,
pgle_profiler: profiler.PGLEProfiler | None = None,
intermediate_shardings: Sequence[JSharding] | None = None,
@ -2951,7 +2941,7 @@ class UnloadedMeshExecutable:
len(ordered_effects))
out_shardings = _maybe_get_and_check_out_shardings(
xla_executable, out_shardings, tuple(da), global_out_avals,
len(ordered_effects), all_default_mem_kind)
len(ordered_effects))
else:
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))