mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Thread out_avals to MeshExecutable
PiperOrigin-RevId: 612037684
This commit is contained in:
parent
8569b893b1
commit
0b70244b1c
@ -2789,7 +2789,7 @@ class UnloadedMeshExecutable:
|
||||
|
||||
def load(self) -> MeshExecutable:
|
||||
return MeshExecutable(self.xla_executable, self.build_unsafe_call,
|
||||
self.input_avals,
|
||||
self.input_avals, self.output_avals,
|
||||
self.input_shardings, self.output_shardings,
|
||||
self.auto_spmd_lowering, self.kept_var_idx,
|
||||
self.in_layouts, self.out_layouts,
|
||||
@ -2942,12 +2942,13 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat):
|
||||
class MeshExecutable(stages.XlaExecutable):
|
||||
__slots__ = [
|
||||
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
|
||||
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
|
||||
"_in_layouts", "_out_layouts", "_all_args_info", "_unloaded_executable",
|
||||
"out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
|
||||
"_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
|
||||
"_unloaded_executable",
|
||||
]
|
||||
|
||||
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
|
||||
out_shardings, auto_spmd_lowering, kept_var_idx,
|
||||
def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
|
||||
in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
|
||||
in_layouts, out_layouts,
|
||||
all_args_info: AllArgsInfo | None = None,
|
||||
unloaded_executable=None):
|
||||
@ -2956,6 +2957,7 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
# in_avals is a list of global and local avals. Aval is global if input
|
||||
# is a GDA or jax.Array else local.
|
||||
self.in_avals = in_avals
|
||||
self.out_avals = out_avals
|
||||
self._unsafe_call = None
|
||||
self._in_shardings = in_shardings
|
||||
self._out_shardings = out_shardings
|
||||
@ -3118,8 +3120,9 @@ def _compile_replicated_mesh_executable_from_hlo(
|
||||
committed=committed, pmap_nreps=pmap_nreps)
|
||||
xla_executable = None
|
||||
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
|
||||
in_shardings, out_shardings, auto_spmd_lowering,
|
||||
kept_var_idx, (None,) * len(global_in_avals),
|
||||
global_out_avals, in_shardings, out_shardings,
|
||||
auto_spmd_lowering, kept_var_idx,
|
||||
(None,) * len(global_in_avals),
|
||||
(None,) * len(global_out_avals))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user