mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Set out_mut to None
as default on from_hlo
instead of in __init__
of MeshComputation
and correct the types too.
PiperOrigin-RevId: 611814102
This commit is contained in:
parent
cfeb1130dd
commit
2761f266d5
@ -891,7 +891,7 @@ class UnloadedPmapExecutable:
|
||||
self.unordered_effects,
|
||||
self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks),
|
||||
set(range(len(input_indices))), [])
|
||||
set(range(len(input_indices))), None)
|
||||
return execute_fun
|
||||
|
||||
def load(self) -> PmapExecutable:
|
||||
@ -1155,7 +1155,7 @@ class ExecuteReplicated:
|
||||
unordered_effects: list[core.Effect],
|
||||
ordered_effects: list[core.Effect], keepalive: Any,
|
||||
has_host_callbacks: bool, kept_var_idx: set[int],
|
||||
out_mut: Sequence[int | None]):
|
||||
out_mut: Sequence[int | None] | None):
|
||||
self.xla_executable = xla_executable
|
||||
self.name = name
|
||||
self.backend = backend
|
||||
@ -1210,7 +1210,7 @@ class ExecuteReplicated:
|
||||
out = self.out_handler(out_arrays)
|
||||
else:
|
||||
out = results.consume_with_handlers(self.out_handler.handlers)
|
||||
if not self.out_mut:
|
||||
if self.out_mut is None:
|
||||
return out
|
||||
else:
|
||||
out_ = []
|
||||
@ -2282,7 +2282,6 @@ def lower_mesh_computation(
|
||||
host_callbacks=lowering_result.host_callbacks,
|
||||
keepalive=lowering_result.keepalive,
|
||||
kept_var_idx=set(range(len(global_in_avals))),
|
||||
out_mut=None,
|
||||
backend=backend,
|
||||
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
|
||||
committed=True,
|
||||
@ -2297,7 +2296,6 @@ class MeshComputation(stages.XlaLowering):
|
||||
|
||||
def __init__(self, name: str, hlo: ir.Module | None,
|
||||
donated_invars: Sequence[bool], **compile_args):
|
||||
compile_args.setdefault('out_mut', None) # TODO(mattjj): remove default
|
||||
self._name = name
|
||||
self._hlo = hlo
|
||||
self._donated_invars = donated_invars
|
||||
@ -2763,7 +2761,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive: Sequence[Any]
|
||||
host_callbacks: Sequence[Any]
|
||||
kept_var_idx: set[int]
|
||||
out_mut: Sequence[None | int]
|
||||
out_mut: Sequence[None | int] | None
|
||||
auto_spmd_lowering: bool
|
||||
in_layouts: Sequence[SpecifiedLayout | None]
|
||||
out_layouts: Sequence[SpecifiedLayout | None]
|
||||
@ -2802,7 +2800,7 @@ class UnloadedMeshExecutable:
|
||||
global_out_avals: Sequence[ShapedArray],
|
||||
in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
|
||||
out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
|
||||
UnspecifiedValue)],
|
||||
UnspecifiedValue)],
|
||||
spmd_lowering: bool,
|
||||
tuple_args: bool,
|
||||
auto_spmd_lowering: bool,
|
||||
@ -2811,13 +2809,13 @@ class UnloadedMeshExecutable:
|
||||
host_callbacks: list[Any],
|
||||
keepalive: Any,
|
||||
kept_var_idx: set[int],
|
||||
out_mut: Sequence[None | int],
|
||||
backend: xb.XlaBackend,
|
||||
device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore
|
||||
committed: bool,
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
pmap_nreps: int = 1,
|
||||
out_mut: Sequence[None | int] | None = None,
|
||||
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
|
||||
all_default_mem_kind: bool = True,
|
||||
all_args_info: AllArgsInfo | None = None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user