From 2761f266d5875ff1180af08c3e4ea86a5d46cf43 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 1 Mar 2024 09:27:57 -0800 Subject: [PATCH] Set out_mut to `None` as default on `from_hlo` instead of in `__init__` of `MeshComputation` and correct the types too. PiperOrigin-RevId: 611814102 --- jax/_src/interpreters/pxla.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 29d4713c1..3a03cbec3 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -891,7 +891,7 @@ class UnloadedPmapExecutable: self.unordered_effects, self.ordered_effects, self.keepalive, bool(self.host_callbacks), - set(range(len(input_indices))), []) + set(range(len(input_indices))), None) return execute_fun def load(self) -> PmapExecutable: @@ -1155,7 +1155,7 @@ class ExecuteReplicated: unordered_effects: list[core.Effect], ordered_effects: list[core.Effect], keepalive: Any, has_host_callbacks: bool, kept_var_idx: set[int], - out_mut: Sequence[int | None]): + out_mut: Sequence[int | None] | None): self.xla_executable = xla_executable self.name = name self.backend = backend @@ -1210,7 +1210,7 @@ class ExecuteReplicated: out = self.out_handler(out_arrays) else: out = results.consume_with_handlers(self.out_handler.handlers) - if not self.out_mut: + if self.out_mut is None: return out else: out_ = [] @@ -2282,7 +2282,6 @@ def lower_mesh_computation( host_callbacks=lowering_result.host_callbacks, keepalive=lowering_result.keepalive, kept_var_idx=set(range(len(global_in_avals))), - out_mut=None, backend=backend, device_assignment=_create_da_object(tuple(mesh.devices.flat)), committed=True, @@ -2297,7 +2296,6 @@ class MeshComputation(stages.XlaLowering): def __init__(self, name: str, hlo: ir.Module | None, donated_invars: Sequence[bool], **compile_args): - compile_args.setdefault('out_mut', None) # TODO(mattjj): remove default self._name = name self._hlo = hlo self._donated_invars = donated_invars @@ -2763,7 +2761,7 @@ class UnloadedMeshExecutable: keepalive: Sequence[Any] host_callbacks: Sequence[Any] kept_var_idx: set[int] - out_mut: Sequence[None | int] + out_mut: Sequence[None | int] | None auto_spmd_lowering: bool in_layouts: Sequence[SpecifiedLayout | None] out_layouts: Sequence[SpecifiedLayout | None] @@ -2802,7 +2800,7 @@ class UnloadedMeshExecutable: global_out_avals: Sequence[ShapedArray], in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO], out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO | - UnspecifiedValue)], + UnspecifiedValue)], spmd_lowering: bool, tuple_args: bool, auto_spmd_lowering: bool, @@ -2811,13 +2809,13 @@ class UnloadedMeshExecutable: host_callbacks: list[Any], keepalive: Any, kept_var_idx: set[int], - out_mut: Sequence[None | int], backend: xb.XlaBackend, device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore committed: bool, in_layouts: MaybeLayout, out_layouts: MaybeLayout, pmap_nreps: int = 1, + out_mut: Sequence[None | int] | None = None, shape_poly_state: mlir.ShapePolyLoweringState | None = None, all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None,