Thread out_avals to MeshExecutable

PiperOrigin-RevId: 612037684
This commit is contained in:
Yash Katariya 2024-03-02 13:34:46 -08:00 committed by jax authors
parent 8569b893b1
commit 0b70244b1c

View File

@ -2789,7 +2789,7 @@ class UnloadedMeshExecutable:
def load(self) -> MeshExecutable:
return MeshExecutable(self.xla_executable, self.build_unsafe_call,
self.input_avals,
self.input_avals, self.output_avals,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.in_layouts, self.out_layouts,
@ -2942,12 +2942,13 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat):
class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_in_layouts", "_out_layouts", "_all_args_info", "_unloaded_executable",
"out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
"_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
"_unloaded_executable",
]
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
in_layouts, out_layouts,
all_args_info: AllArgsInfo | None = None,
unloaded_executable=None):
@ -2956,6 +2957,7 @@ class MeshExecutable(stages.XlaExecutable):
# in_avals is a list of global and local avals. Aval is global if input
# is a GDA or jax.Array else local.
self.in_avals = in_avals
self.out_avals = out_avals
self._unsafe_call = None
self._in_shardings = in_shardings
self._out_shardings = out_shardings
@ -3118,8 +3120,9 @@ def _compile_replicated_mesh_executable_from_hlo(
committed=committed, pmap_nreps=pmap_nreps)
xla_executable = None
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, (None,) * len(global_in_avals),
global_out_avals, in_shardings, out_shardings,
auto_spmd_lowering, kept_var_idx,
(None,) * len(global_in_avals),
(None,) * len(global_out_avals))