mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Clean up some code in pxla.py that deals with jaxprs and avals. Lift the discharging of refs into a separate function and remove the global_in_avals argument from lower_sharding_computation.
PiperOrigin-RevId: 628564679
This commit is contained in:
parent
d9b75350b7
commit
755f350910
@ -1732,30 +1732,28 @@ def prune_unused_inputs(
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
|
||||
def _dce_jaxpr(closed_jaxpr, api_name, fun_name,
|
||||
keep_unused, donated_invars, auto_spmd_lowering):
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
|
||||
|
||||
assert isinstance(closed_jaxpr, core.ClosedJaxpr)
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
consts = closed_jaxpr.consts
|
||||
in_avals = closed_jaxpr.in_avals
|
||||
|
||||
if (keep_unused or auto_spmd_lowering or
|
||||
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
|
||||
for a in global_in_avals)):
|
||||
kept_var_idx = set(range(len(global_in_avals)))
|
||||
for a in in_avals)):
|
||||
kept_var_idx = set(range(len(in_avals)))
|
||||
else:
|
||||
jaxpr, kept_const_idx, kept_var_idx = prune_unused_inputs(jaxpr)
|
||||
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
|
||||
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
|
||||
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
|
||||
del kept_const_idx
|
||||
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
|
||||
kept_var_idx, name_stack)
|
||||
return closed_jaxpr, donated_invars, kept_var_idx, name_stack
|
||||
|
||||
class MutationData(NamedTuple):
|
||||
in_mut: list[core.MutableArray]
|
||||
@ -2034,6 +2032,27 @@ def spmd_mode_check(da_object, inline):
|
||||
return
|
||||
|
||||
|
||||
def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
|
||||
donated_invars, out_shardings, out_layouts):
|
||||
if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
|
||||
closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
|
||||
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
|
||||
in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
|
||||
donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
|
||||
out_layouts_ = iter(zip(out_shardings, out_layouts))
|
||||
out_shardings, out_layouts = unzip2(
|
||||
next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
|
||||
for i in mut.out_mut)
|
||||
assert next(out_layouts_, None) is None
|
||||
else:
|
||||
inout_aliases = mut = None
|
||||
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
|
||||
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
|
||||
|
||||
return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
|
||||
donated_invars, out_shardings, out_layouts)
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_sharding_computation(
|
||||
closed_jaxpr: core.ClosedJaxpr,
|
||||
@ -2044,7 +2063,6 @@ def lower_sharding_computation(
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
donated_invars: Sequence[bool],
|
||||
global_in_avals: Sequence[core.ShapedArray],
|
||||
*,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
@ -2062,34 +2080,23 @@ def lower_sharding_computation(
|
||||
auto_spmd_lowering = check_if_any_auto(
|
||||
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
|
||||
|
||||
all_args_info = AllArgsInfo(global_in_avals, closed_jaxpr.jaxpr.debug_info)
|
||||
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr.debug_info)
|
||||
|
||||
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
|
||||
kept_var_idx, name_stack) = _dce_jaxpr(
|
||||
closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
|
||||
donated_invars, auto_spmd_lowering)
|
||||
closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr(
|
||||
closed_jaxpr, api_name, fun_name, keep_unused, donated_invars,
|
||||
auto_spmd_lowering)
|
||||
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
|
||||
in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
|
||||
|
||||
if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
|
||||
closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
|
||||
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
|
||||
in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
|
||||
donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
|
||||
out_layouts_ = iter(zip(out_shardings, out_layouts))
|
||||
out_shardings, out_layouts = unzip2(
|
||||
next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
|
||||
for i in mut.out_mut)
|
||||
assert next(out_layouts_, None) is None
|
||||
# TODO(yashkatariya): remove global_in_avals / global_out_avals
|
||||
global_in_avals = closed_jaxpr.in_avals
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
else:
|
||||
inout_aliases = mut = None
|
||||
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
|
||||
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
|
||||
(closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
|
||||
donated_invars, out_shardings, out_layouts) = _discharge_refs_jaxpr(
|
||||
closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings,
|
||||
out_layouts)
|
||||
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
global_in_avals = closed_jaxpr.in_avals
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
len(out_shardings), len(out_layouts), len(global_out_avals))
|
||||
|
||||
|
@ -715,7 +715,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
core.ClosedJaxpr(jaxpr, consts), 'jit', name,
|
||||
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
|
||||
(None,) * len(in_avals), (None,) * len(out_avals),
|
||||
donated_invars, in_avals, keep_unused=True, inline=False,
|
||||
donated_invars, keep_unused=True, inline=False,
|
||||
devices_from_context=None, lowering_parameters=lowering_parameters)
|
||||
|
||||
|
||||
|
@ -1593,7 +1593,7 @@ def _pjit_lower_cached(
|
||||
else:
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, api_name, name, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, tuple(donated_invars), tuple(jaxpr.in_avals),
|
||||
in_layouts, out_layouts, tuple(donated_invars),
|
||||
keep_unused=keep_unused, inline=inline,
|
||||
devices_from_context=(
|
||||
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user