mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove dead codepaths now that MemorySpaceDescription works in OSS
PiperOrigin-RevId: 715410774
This commit is contained in:
parent
ee724565bf
commit
b7e06f1937
@ -2368,9 +2368,6 @@ def lower_sharding_computation(
|
||||
out_layouts=out_layouts,
|
||||
pmap_nreps=nreps,
|
||||
shape_poly_state=shape_poly_state,
|
||||
# TODO(yashkatariya): Remove `all_default_mem_kind` after
|
||||
# MemoryDescription works in OSS.
|
||||
all_default_mem_kind=all_default_mem_kind,
|
||||
all_args_info=all_args_info,
|
||||
pgle_profiler=pgle_profiler,
|
||||
intermediate_shardings=unique_intermediate_shardings,
|
||||
@ -2442,21 +2439,15 @@ def get_out_shardings_from_executable(
|
||||
device_assignment: Sequence[xc.Device],
|
||||
num_out_avals: int,
|
||||
num_ordered_effects: int,
|
||||
all_default_mem_kind: bool,
|
||||
) -> Sequence[sharding_impls.GSPMDSharding] | None:
|
||||
from jax._src import pjit
|
||||
|
||||
# TODO(yashkatariya): Remove `all_default_mem_kind` branch after
|
||||
# MemoryDescription works in OSS.
|
||||
if all_default_mem_kind:
|
||||
try:
|
||||
omk = xla_executable.get_output_memory_kinds()[0]
|
||||
if num_ordered_effects > 0:
|
||||
omk = omk[num_ordered_effects:]
|
||||
except:
|
||||
omk = [None] * num_out_avals
|
||||
else:
|
||||
try:
|
||||
omk = xla_executable.get_output_memory_kinds()[0]
|
||||
if num_ordered_effects > 0:
|
||||
omk = omk[num_ordered_effects:]
|
||||
except:
|
||||
omk = [None] * num_out_avals
|
||||
|
||||
assert len(omk) == num_out_avals, (len(omk), num_out_avals)
|
||||
|
||||
@ -2781,11 +2772,11 @@ def _maybe_get_and_check_in_shardings(
|
||||
|
||||
def _maybe_get_and_check_out_shardings(
|
||||
xla_executable, out_shardings, device_assignment, global_out_avals,
|
||||
num_ordered_effects, all_default_mem_kind
|
||||
num_ordered_effects
|
||||
):
|
||||
out_shardings_xla = get_out_shardings_from_executable(
|
||||
xla_executable, device_assignment, len(global_out_avals),
|
||||
num_ordered_effects, all_default_mem_kind)
|
||||
num_ordered_effects)
|
||||
if out_shardings_xla is None:
|
||||
return out_shardings
|
||||
|
||||
@ -2893,7 +2884,6 @@ class UnloadedMeshExecutable:
|
||||
pmap_nreps: int = 1,
|
||||
mut: MutationData | None = None,
|
||||
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
|
||||
all_default_mem_kind: bool = True,
|
||||
all_args_info: AllArgsInfo | None = None,
|
||||
pgle_profiler: profiler.PGLEProfiler | None = None,
|
||||
intermediate_shardings: Sequence[JSharding] | None = None,
|
||||
@ -2951,7 +2941,7 @@ class UnloadedMeshExecutable:
|
||||
len(ordered_effects))
|
||||
out_shardings = _maybe_get_and_check_out_shardings(
|
||||
xla_executable, out_shardings, tuple(da), global_out_avals,
|
||||
len(ordered_effects), all_default_mem_kind)
|
||||
len(ordered_effects))
|
||||
else:
|
||||
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
|
||||
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
|
||||
|
Loading…
x
Reference in New Issue
Block a user