Clean up some code in pxla.py that deals with jaxpr and avals. Lift the discharging of refs into a separate function and remove the global_in_avals argument from lower_sharding_computation.

PiperOrigin-RevId: 628564679
This commit is contained in:
Yash Katariya 2024-04-26 18:27:26 -07:00 committed by jax authors
parent d9b75350b7
commit 755f350910
3 changed files with 39 additions and 32 deletions

View File

@ -1732,30 +1732,28 @@ def prune_unused_inputs(
@weakref_lru_cache
def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
def _dce_jaxpr(closed_jaxpr, api_name, fun_name,
keep_unused, donated_invars, auto_spmd_lowering):
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
assert isinstance(closed_jaxpr, core.ClosedJaxpr)
jaxpr = closed_jaxpr.jaxpr
global_out_avals = closed_jaxpr.out_avals
consts = closed_jaxpr.consts
in_avals = closed_jaxpr.in_avals
if (keep_unused or auto_spmd_lowering or
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
for a in global_in_avals)):
kept_var_idx = set(range(len(global_in_avals)))
for a in in_avals)):
kept_var_idx = set(range(len(in_avals)))
else:
jaxpr, kept_const_idx, kept_var_idx = prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
kept_var_idx, name_stack)
return closed_jaxpr, donated_invars, kept_var_idx, name_stack
class MutationData(NamedTuple):
in_mut: list[core.MutableArray]
@ -2034,6 +2032,27 @@ def spmd_mode_check(da_object, inline):
return
# Rewrites `closed_jaxpr` via the ref-discharging helpers when its effects call
# for it, and keeps the per-input / per-output metadata in step with the
# rewritten jaxpr.
#
# Returns an 8-tuple: (closed_jaxpr, inout_aliases, mut, in_shardings,
# in_layouts, donated_invars, out_shardings, out_layouts). `inout_aliases` and
# `mut` are None when no RefEffect is found among the jaxpr's effects.
#
# NOTE(review): leading indentation is not visible in this view, so whether the
# InternalMutableArray check below is nested inside `else:` or runs after the
# whole if/else cannot be confirmed from here -- verify against the real file.
def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
donated_invars, out_shardings, out_layouts):
# Only take the discharge path when at least one effect is a RefEffect.
if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
# Extend each per-input sequence by one default entry (UNSPECIFIED sharding,
# no layout, not donated) per element of `mut.in_mut`.
# presumably these line up with extra inputs added by _discharge_refs -- TODO confirm
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
# Rebuild the per-output sequences by walking `mut.out_mut`: an entry of None
# consumes the next caller-supplied (sharding, layout) pair, while an integer
# entry `i` reuses the sharding and layout of input `i`.
out_layouts_ = iter(zip(out_shardings, out_layouts))
out_shardings, out_layouts = unzip2(
next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
for i in mut.out_mut)
# Every caller-supplied output pair must have been consumed above.
assert next(out_layouts_, None) is None
else:
# No RefEffect: there is no aliasing or mutation data to report.
inout_aliases = mut = None
# Separate pass for InternalMutableArray effects; only the jaxpr is replaced,
# the sharding/layout/donation sequences are left as they are.
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
donated_invars, out_shardings, out_layouts)
@profiler.annotate_function
def lower_sharding_computation(
closed_jaxpr: core.ClosedJaxpr,
@ -2044,7 +2063,6 @@ def lower_sharding_computation(
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
keep_unused: bool,
inline: bool,
@ -2062,34 +2080,23 @@ def lower_sharding_computation(
auto_spmd_lowering = check_if_any_auto(
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
all_args_info = AllArgsInfo(global_in_avals, closed_jaxpr.jaxpr.debug_info)
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr.debug_info)
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _dce_jaxpr(
closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
donated_invars, auto_spmd_lowering)
closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr(
closed_jaxpr, api_name, fun_name, keep_unused, donated_invars,
auto_spmd_lowering)
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
out_layouts_ = iter(zip(out_shardings, out_layouts))
out_shardings, out_layouts = unzip2(
next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
for i in mut.out_mut)
assert next(out_layouts_, None) is None
# TODO(yashkatariya): remove global_in_avals / global_out_avals
global_in_avals = closed_jaxpr.in_avals
global_out_avals = closed_jaxpr.out_avals
else:
inout_aliases = mut = None
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
(closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
donated_invars, out_shardings, out_layouts) = _discharge_refs_jaxpr(
closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings,
out_layouts)
jaxpr = closed_jaxpr.jaxpr
global_in_avals = closed_jaxpr.in_avals
global_out_avals = closed_jaxpr.out_avals
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))

View File

@ -715,7 +715,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
core.ClosedJaxpr(jaxpr, consts), 'jit', name,
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
(None,) * len(in_avals), (None,) * len(out_avals),
donated_invars, in_avals, keep_unused=True, inline=False,
donated_invars, keep_unused=True, inline=False,
devices_from_context=None, lowering_parameters=lowering_parameters)

View File

@ -1593,7 +1593,7 @@ def _pjit_lower_cached(
else:
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars), tuple(jaxpr.in_avals),
in_layouts, out_layouts, tuple(donated_invars),
keep_unused=keep_unused, inline=inline,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),