Default out_mut to None in from_hlo rather than in MeshComputation.__init__, and correct its type annotations to allow None.

PiperOrigin-RevId: 611814102
This commit is contained in:
Yash Katariya 2024-03-01 09:27:57 -08:00 committed by jax authors
parent cfeb1130dd
commit 2761f266d5

View File

@@ -891,7 +891,7 @@ class UnloadedPmapExecutable:
self.unordered_effects,
self.ordered_effects, self.keepalive,
bool(self.host_callbacks),
set(range(len(input_indices))), [])
set(range(len(input_indices))), None)
return execute_fun
def load(self) -> PmapExecutable:
@@ -1155,7 +1155,7 @@ class ExecuteReplicated:
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: set[int],
out_mut: Sequence[int | None]):
out_mut: Sequence[int | None] | None):
self.xla_executable = xla_executable
self.name = name
self.backend = backend
@@ -1210,7 +1210,7 @@ class ExecuteReplicated:
out = self.out_handler(out_arrays)
else:
out = results.consume_with_handlers(self.out_handler.handlers)
if not self.out_mut:
if self.out_mut is None:
return out
else:
out_ = []
@@ -2282,7 +2282,6 @@ def lower_mesh_computation(
host_callbacks=lowering_result.host_callbacks,
keepalive=lowering_result.keepalive,
kept_var_idx=set(range(len(global_in_avals))),
out_mut=None,
backend=backend,
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True,
@@ -2297,7 +2296,6 @@ class MeshComputation(stages.XlaLowering):
def __init__(self, name: str, hlo: ir.Module | None,
donated_invars: Sequence[bool], **compile_args):
compile_args.setdefault('out_mut', None) # TODO(mattjj): remove default
self._name = name
self._hlo = hlo
self._donated_invars = donated_invars
@@ -2763,7 +2761,7 @@ class UnloadedMeshExecutable:
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
kept_var_idx: set[int]
out_mut: Sequence[None | int]
out_mut: Sequence[None | int] | None
auto_spmd_lowering: bool
in_layouts: Sequence[SpecifiedLayout | None]
out_layouts: Sequence[SpecifiedLayout | None]
@@ -2802,7 +2800,7 @@ class UnloadedMeshExecutable:
global_out_avals: Sequence[ShapedArray],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
UnspecifiedValue)],
UnspecifiedValue)],
spmd_lowering: bool,
tuple_args: bool,
auto_spmd_lowering: bool,
@@ -2811,13 +2809,13 @@ class UnloadedMeshExecutable:
host_callbacks: list[Any],
keepalive: Any,
kept_var_idx: set[int],
out_mut: Sequence[None | int],
backend: xb.XlaBackend,
device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore
committed: bool,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
pmap_nreps: int = 1,
out_mut: Sequence[None | int] | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
all_default_mem_kind: bool = True,
all_args_info: AllArgsInfo | None = None,