output res forwarding optimization for shard_map and jit

This commit is contained in:
Matthew Johnson 2023-10-12 16:00:08 -07:00
parent 6aff74e7ff
commit 0944010186
4 changed files with 272 additions and 112 deletions

View File

@ -834,6 +834,37 @@ def trace_to_subjaxpr_nounits_fwd(
del out_tracers del out_tracers
yield jaxpr, (fwds, out_pvals, pruned_consts, env) yield jaxpr, (fwds, out_pvals, pruned_consts, env)
# The below variant implements two optimizations:
# 1. residuals that are also primal inputs are indicated in aux data rather
#    than passed as outputs;
# 2. residuals that are also primal outputs are indicated in aux data rather
#    than passed as redundant outputs.
@lu.transformation
def trace_to_subjaxpr_nounits_fwd2(
    main: core.MainTrace,
    instantiate: bool | Sequence[bool],
    in_pvals: Sequence[PartialVal]):
  """Trace to a subjaxpr, recording residual forwarding in aux data.

  Yields ``jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)``.
  For each residual (const) ``c``, ``input_fwds[i]`` is the index of the known
  primal input that ``c`` is (matched by object identity), or None; likewise
  ``output_fwds[i]`` indexes the known primal output ``c`` forwards, or None.
  Residuals matched either way are omitted from ``pruned_consts``, so callers
  must reconstruct them from the aux indices.
  """
  assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
  out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits(
      main, instantiate, in_pvals)
  out_pvals = [t.pval for t in out_tracers]

  # Which consts (aka residuals) are just forwarded inputs? Check obj id.
  in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
  id_map = {id(c): i for i, c in enumerate(in_consts)}
  input_fwds: list[int | None] = [id_map.get(id(c)) for c in consts]

  # Which consts (aka residuals) are already primal outputs? Check obj id.
  out_consts = [pval.get_known() for pval in out_pvals if pval.is_known()]
  id_map = {id(c): i for i, c in enumerate(out_consts)}
  output_fwds: list[int | None] = [id_map.get(id(c)) for c in consts]

  # Keep only residuals that are neither forwarded inputs nor outputs; a
  # residual may match both maps, in which case it is still pruned once.
  pruned_consts = [c for c, f1, f2 in zip(consts, input_fwds, output_fwds)
                   if f1 is None and f2 is None]

  del out_tracers
  yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)
FreeVar = namedtuple('FreeVar', ['val']) FreeVar = namedtuple('FreeVar', ['val'])
ConstVar = namedtuple('ConstVar', ['val']) ConstVar = namedtuple('ConstVar', ['val'])
@ -1359,48 +1390,72 @@ def call_partial_eval_custom_rule(
if type(x) is Var and not inst] if type(x) is Var and not inst]
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
# TODO(mattjj): unify with ParamsUpdater (this one takes an extra int)
# Signature: (unks_in, inst_in, kept_outs_known, kept_outs_staged,
#             num_res_out, num_res, params_known, params_staged)
#         -> (params_known, params_staged)
# The two ints are the count of non-forwarded residual value outputs of the
# known jaxpr and the total residual count fed to the staged jaxpr.
ParamsUpdater2 = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
                           Sequence[bool], int, int, dict, dict],
                          tuple[dict, dict]]
def closed_call_partial_eval_custom_rule( def closed_call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater, jaxpr_param_name: str, params_updater: ParamsUpdater2,
saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule. # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
closed_jaxpr = eqn.params[jaxpr_param_name] closed_jaxpr = eqn.params[jaxpr_param_name]
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \ jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_val, num_res_ref = \
partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in, partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in,
False, False, saveable) False, False, saveable)
num_res = num_res_ref + num_res_out num_res = num_res_ref + num_res_val
# Compute which residual value outputs are also *undropped* primal outputs.
num_out_primals = len(jaxpr_known_.outvars) - num_res_val
out_vars, res_vars = split_list(jaxpr_known_.outvars, [num_out_primals])
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known))
if type(b) is not DropVar}
out_fwd = [idx_map.get(id(v)) for v in res_vars]
# Prune jaxpr_known_ outputs by removing forwards.
jaxpr_known_ = prune_jaxpr_outputs(
jaxpr_known_, [True] * num_out_primals + [f is None for f in out_fwd])
# Forming these fresh ClosedJaxprs defeats caching, but caller handles caching # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts) jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts) jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
ins_known, _ = partition_list(unks_in, eqn.invars) ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, ins_staged = partition_list(inst_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars) _, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr]) newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr])
params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
params_known, params_staged = params_updater( params_known, params_staged = params_updater(
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, unks_in, inst_in, map(op.not_, unks_out), inst_out,
params_staged) sum(f is None for f in out_fwd), num_res, params_known, params_staged)
residuals, ref_residuals = split_list( res_val_binders, res_ref_binders = split_list(
[newvar(res_aval(params_known, v)) for v [newvar(res_aval(params_known, v))
in jaxpr_staged.in_avals[:num_res]], [num_res_out]) for v in jaxpr_staged.in_avals[:num_res]], [num_res_val])
eqn_known = new_jaxpr_eqn([*ins_known, *ref_residuals], res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None]
[*out_binders_known, *residuals], res_val_binders_ = iter(res_val_binders)
res_val_vars = [out_binders_known[f] if f is not None
else next(res_val_binders_) for f in out_fwd]
sentinel = object()
assert next(res_val_binders_, sentinel) is sentinel
eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders],
[*out_binders_known, *res_val_binders],
eqn.primitive, params_known, jaxpr_known.effects, eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info) eqn.source_info)
eqn_staged = new_jaxpr_eqn([*residuals, *ref_residuals, *ins_staged], eqn_staged = new_jaxpr_eqn([*res_val_vars, *res_ref_binders, *ins_staged],
out_binders_staged, out_binders_staged,
eqn.primitive, params_staged, jaxpr_staged.effects, eqn.primitive, params_staged, jaxpr_staged.effects,
eqn.source_info) eqn.source_info)
assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals) assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
assert len(ins_known) + len(ref_residuals) == len(jaxpr_known.jaxpr.invars) assert len(ins_known) + len(res_ref_binders) == len(jaxpr_known.jaxpr.invars)
assert len(ins_staged) + len(ref_residuals) + len(residuals) == len(jaxpr_staged.jaxpr.invars) assert len(ins_staged) + len(res_ref_binders) + len(res_val_vars) == len(jaxpr_staged.jaxpr.invars)
assert len(out_binders_known) + len(residuals) == len(jaxpr_known.jaxpr.outvars) assert len(out_binders_known) + len(res_val_binders) == len(jaxpr_known.jaxpr.outvars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in) new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is Var and not inst] if type(x) is Var and not inst]
new_vars = [*new_inst, *residuals, *ref_residuals] new_vars = [*new_inst, *res_val_vars, *res_ref_binders]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars return eqn_known, eqn_staged, unks_out, inst_out, new_vars
partial_eval_jaxpr_custom_rules[core.call_p] = \ partial_eval_jaxpr_custom_rules[core.call_p] = \
@ -1408,7 +1463,7 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
lambda _, __, ___, ____, _____, x, y: (x, y)) lambda _, __, ___, ____, _____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[core.closed_call_p] = \ partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
partial(closed_call_partial_eval_custom_rule, 'call_jaxpr', partial(closed_call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, _____, x, y: (x, y)) lambda _, __, ___, ____, _____, ______, x, y: (x, y))
def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]: def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]:
@ -1427,6 +1482,33 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]:
for v in jaxpr.outvars] for v in jaxpr.outvars]
def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
  """Return ``jaxpr`` with only the outputs marked True in ``used_outputs``.

  Public wrapper: normalizes the mask to a hashable tuple and dispatches to
  the weakref-cached implementation.
  """
  mask = tuple(used_outputs)
  return _prune_jaxpr_outputs_cached(jaxpr, mask)


def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
  # Drop every outvar whose mask entry is False.
  kept_outvars = [var for keep, var in zip(used_outputs, jaxpr.outvars) if keep]
  if jaxpr.debug_info:
    # Filter result_paths in lockstep so debug info stays aligned with outputs.
    info = jaxpr.debug_info
    kept_paths = tuple(path for keep, path in
                       zip(used_outputs, info.result_paths) if keep)
    dbg = core.JaxprDebugInfo(info.traced_for, info.func_src_info,
                              info.arg_names, kept_paths)
  else:
    # Preserve whatever falsy value (e.g. None) was stored on the jaxpr.
    dbg = jaxpr.debug_info
  pruned = jaxpr.replace(outvars=kept_outvars, debug_info=dbg)
  if config.enable_checks.value:
    core.check_jaxpr(pruned)
  return pruned

_prune_jaxpr_outputs_cached = weakref_lru_cache(_prune_jaxpr_outputs)
def prune_closed_jaxpr_outputs(
    jaxpr: ClosedJaxpr, used_outputs: Sequence[bool]
) -> ClosedJaxpr:
  """Return ``jaxpr`` keeping only outputs marked True in ``used_outputs``."""
  mask = tuple(used_outputs)
  return _prune_closed_jaxpr_outputs(jaxpr, mask)

@weakref_lru_cache
def _prune_closed_jaxpr_outputs(
    jaxpr: ClosedJaxpr, used_outputs: tuple[bool, ...]
) -> ClosedJaxpr:
  # Prune the inner (open) jaxpr, then re-close over the same consts.
  inner = _prune_jaxpr_outputs(jaxpr.jaxpr, used_outputs)
  return ClosedJaxpr(inner, jaxpr.consts)
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool], def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: bool | Sequence[bool] = False, instantiate: bool | Sequence[bool] = False,
) -> tuple[Jaxpr, list[bool]]: ) -> tuple[Jaxpr, list[bool]]:

View File

@ -1514,9 +1514,9 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
@weakref_lru_cache @weakref_lru_cache
def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr, def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
fwds_known: tuple[Optional[int]]) -> core.ClosedJaxpr: in_fwd: tuple[Optional[int]]) -> core.ClosedJaxpr:
updated_jaxpr = known_jaxpr.jaxpr.replace( updated_jaxpr = known_jaxpr.jaxpr.replace(
outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known) outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, in_fwd)
if i is None]) if i is None])
return known_jaxpr.replace(jaxpr=updated_jaxpr) return known_jaxpr.replace(jaxpr=updated_jaxpr)
@ -1533,59 +1533,55 @@ def _pjit_partial_eval(trace, *in_tracers,
unknown_outs = tuple(unknown_outs) unknown_outs = tuple(unknown_outs)
known_outs = tuple(not uk for uk in unknown_outs) known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(res_avals) num_residuals = len(res_avals)
res_shardings = (UNSPECIFIED,) * num_residuals
def keep_where(l, should_keep): def keep_where(l, should_keep):
return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep) return tuple(x for x, keep in zip(l, should_keep) if keep)
# Compute which outputs are just forwarded inputs.
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
] + in_fwd_res
del in_fwd_primal, in_fwd_res
# Compute which residuals are just primal outputs.
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
# Prune jaxpr outputs and out_shardings by removing forwards.
keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
known_out_shardings = keep_where(known_out_shardings, keep)
del keep, num_out_primals
residual_shardings = (UNSPECIFIED,) * num_residuals
# Compute the known outputs
known_params = dict( known_params = dict(
jaxpr=known_jaxpr, jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
in_shardings=keep_where(in_shardings, known_ins), out_shardings=known_out_shardings, resource_env=resource_env,
out_shardings=(
keep_where(out_shardings, known_outs) + residual_shardings),
resource_env=resource_env,
donated_invars=keep_where(donated_invars, known_ins), donated_invars=keep_where(donated_invars, known_ins),
name=name, name=name, keep_unused=keep_unused, inline=inline)
keep_unused=keep_unused,
inline=inline)
fwds_known = pe._jaxpr_forwarding(known_params['jaxpr'].jaxpr)
# Only forward the outvars where the out_sharding is UNSPECIFIED.
known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs)
fwds_known_user = [
fwd if is_unspecified(os) else None
for os, fwd in zip(known_user_out_shardings,
fwds_known[:len(known_user_out_shardings)])]
fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):]
del fwds_known_user
# Remove forwarded outvars and out_shardings
known_params['jaxpr'] = _known_jaxpr_fwd(known_params['jaxpr'], tuple(fwds_known))
known_out_shardings = tuple(
s for s, i in zip(known_params['out_shardings'], fwds_known) if i is None)
known_params['out_shardings'] = known_out_shardings
del known_out_shardings
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
# Bind known things to pjit_p. # Bind known things to pjit_p.
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()] known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
all_known_outs = pjit_p.bind(*known_inputs, **known_params) all_known_outs = pjit_p.bind(*known_inputs, **known_params)
known_outs_iter = iter(all_known_outs) all_known_outs_ = iter(all_known_outs)
all_known_outs = [next(known_outs_iter) all_known_outs = [known_inputs[f1] if f1 is not None else
if fwd_idx is None else known_inputs[fwd_idx] all_known_outs[f2] if f2 is not None else
for fwd_idx in fwds_known] next(all_known_outs_) for f1, f2 in zip(in_fwd, out_fwd)]
assert next(known_outs_iter, None) is None sentinel = object()
del known_outs_iter, known_inputs assert next(all_known_outs_, sentinel) is sentinel
del all_known_outs_, known_inputs
if num_residuals: known_out_vals, residual_vals = \
known_out_vals, residual_vals = \ split_list(all_known_outs, [len(all_known_outs) - num_residuals])
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
else:
known_out_vals, residual_vals = all_known_outs, ()
residual_tracers = [trace.new_instantiated_const(residual) for residual in residual_vals] residual_tracers = [trace.new_instantiated_const(residual) for residual in residual_vals]
# The convention of partial_eval_jaxpr_nounits is to place residual binders # The convention of partial_eval_jaxpr_nounits is to place residual binders
@ -1597,7 +1593,7 @@ def _pjit_partial_eval(trace, *in_tracers,
# Prepare unknown tracers # Prepare unknown tracers
unknown_params = dict( unknown_params = dict(
jaxpr=unknown_jaxpr, jaxpr=unknown_jaxpr,
in_shardings=(keep_where(in_shardings, unknown_ins) + residual_shardings), in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings),
out_shardings=keep_where(out_shardings, unknown_outs), out_shardings=keep_where(out_shardings, unknown_outs),
resource_env=resource_env, resource_env=resource_env,
donated_invars=(keep_where(donated_invars, unknown_ins) + donated_invars=(keep_where(donated_invars, unknown_ins) +
@ -1626,28 +1622,25 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
def _pjit_partial_eval_custom_params_updater( def _pjit_partial_eval_custom_params_updater(
unks_in: Sequence[bool], inst_in: Sequence[bool], unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool], kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict num_res_out: int, num_res_in: int, params_known: dict, params_staged: dict
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
# prune inputs to jaxpr_known according to unks_in # prune inputs to jaxpr_known according to unks_in
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars']) donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings']) in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
if num_res == 0:
residual_shardings = []
else:
residual_shardings = [UNSPECIFIED] * num_res
_, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings']) _, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
new_params_known = dict(params_known, new_params_known = dict(params_known,
in_shardings=tuple(in_shardings_known), in_shardings=tuple(in_shardings_known),
out_shardings=(*out_shardings_known, *residual_shardings), out_shardings=(*out_shardings_known,
*[UNSPECIFIED] * num_res_out),
donated_invars=tuple(donated_invars_known)) donated_invars=tuple(donated_invars_known))
assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals) assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals) assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
# added num_res new inputs to jaxpr_staged, and pruning according to inst_in # added num_res new inputs to jaxpr_staged, and pruning according to inst_in
_, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars']) _, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
donated_invars_staged = [False] * num_res + donated_invars_staged donated_invars_staged = [False] * num_res_in + donated_invars_staged
_, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings']) _, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
in_shardings_staged = [*residual_shardings, *in_shardings_staged] in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged]
_, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings']) _, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])

View File

@ -1262,30 +1262,37 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals) in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits(f, trace.main, False) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.main, False)
f = _promote_scalar_residuals(f) f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits( f_known, aux = pe.partial_eval_wrapper_nounits(
f, (*in_knowns,), (*in_avals_sharded,)) f, (*in_knowns,), (*in_avals_sharded,))
@as_hashable_function(closure=out_names_thunk) @as_hashable_function(closure=out_names_thunk)
def known_out_names(): def known_out_names():
out_knowns, _, jaxpr, _ = aux() in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux()
_, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) _, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
assert not any(not v.aval.shape for v in jaxpr.constvars) num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
res_names = ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars) return (*out_known_names, *({0: (*mesh.axis_names,)},) * num_res)
return (*out_known_names, *res_names)
known_params = dict(mesh=mesh, in_names=(*known_in_names,), known_params = dict(mesh=mesh, in_names=(*known_in_names,),
out_names_thunk=known_out_names, check_rep=check_rep, out_names_thunk=known_out_names, check_rep=check_rep,
rewrite=rewrite, auto=auto) rewrite=rewrite, auto=auto)
out = shard_map_p.bind(f_known, *in_consts, **known_params) out = shard_map_p.bind(f_known, *in_consts, **known_params)
out_knowns, out_avals_sharded, jaxpr, env = aux() in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux()
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
with core.extend_axis_env_nd(mesh.shape.items()): out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
jaxpr = pe.convert_constvars_jaxpr(jaxpr) assert not jaxpr.constvars
unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk()) unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
unk_in_names = (({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env) known_out_names_ = known_out_names()
+ (*unk_in_names,)) non_fwd_res_ = iter(non_fwd_res)
res = [in_consts[f1] if f1 is not None else out_consts[f2] if f2 is not None
else next(non_fwd_res_) for f1, f2 in zip(in_fwd, out_fwd)]
sentinel = object()
assert next(non_fwd_res_, sentinel) is sentinel
res_names = [known_in_names[f1] if f1 is not None else
known_out_names_[f2] if f2 is not None else
{0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
const_tracers = map(trace.new_instantiated_const, res) const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env) env_tracers = map(trace.full_raise, env)
unk_arg_tracers = [t for t in tracers if not t.is_known()] unk_arg_tracers = [t for t in tracers if not t.is_known()]
@ -1307,7 +1314,11 @@ def _shard_map_partial_eval_post_process(
del check_rep del check_rep
unk_tracers = [t for t in tracers if not t.is_known()] unk_tracers = [t for t in tracers if not t.is_known()]
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers) jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
jaxpr, res = _promote_scalar_residuals_jaxpr(jaxpr, res) # TODO(mattjj): output forwarding optimization
which = [not v.aval.shape for v in jaxpr.constvars]
res = [jax.lax.broadcast(x, (1,)) if not v.aval.shape else x
for x, v in zip(res, jaxpr.constvars)]
jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers]) out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
out = [*consts, *res] out = [*consts, *res]
@ -1317,11 +1328,11 @@ def _shard_map_partial_eval_post_process(
def todo(out): def todo(out):
trace = main.with_cur_sublevel() trace = main.with_cur_sublevel()
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) out_consts, res_ = split_list(out, [len(out) - len(res)])
const_tracers = map(trace.new_instantiated_const, res) const_tracers = map(trace.new_instantiated_const, res_)
env_tracers = map(trace.full_raise, env) env_tracers = map(trace.full_raise, env)
staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env) staged_in_names = ({0: (*mesh.axis_names,)},) * len(res_) + ({},) * len(env)
staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names, staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
out_names=(*out_names_unknown,), check_rep=False, out_names=(*out_names_unknown,), check_rep=False,
rewrite=rewrite, auto=auto) rewrite=rewrite, auto=auto)
@ -1339,7 +1350,7 @@ def _shard_map_partial_eval_post_process(
def out_names_transform(out_names): def out_names_transform(out_names):
nonlocal out_names_unknown nonlocal out_names_unknown
out_names_unknown, out_names_known = partition_list(out_knowns, out_names) out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars) return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(res)
out_names_unknown: list | None = None out_names_unknown: list | None = None
return out, (todo, out_names_transform) return out, (todo, out_names_transform)
@ -1347,21 +1358,25 @@ pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
@lu.transformation @lu.transformation
def _promote_scalar_residuals(*args, **kwargs): def _promote_scalar_residuals(*args, **kwargs):
jaxpr, (out_pvals, out_consts, env) = yield args, kwargs jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs
jaxpr, out_consts = _promote_scalar_residuals_jaxpr(jaxpr, out_consts) which = [f1 is None and f2 is None and not v.aval.shape
yield jaxpr, (out_pvals, out_consts, env) for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
def _promote_scalar_residuals_jaxpr(jaxpr, res): out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
which = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape for x in out_consts]
for v in jaxpr.constvars] yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
res_ = [jax.lax.broadcast(x, (1,)) if s else x for x, s in zip(res, which)]
def _promote_scalar_residuals_jaxpr(jaxpr, which):
@lu.wrap_init @lu.wrap_init
def fun(*args): def fun(*res_and_args):
res = [_rem_singleton(x) if s else x for x, s in zip(res_, which)] res, args = split_list(res_and_args, [len(jaxpr.constvars)])
res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
return core.eval_jaxpr(jaxpr, res, *args) return core.eval_jaxpr(jaxpr, res, *args)
jaxpr, _, res = pe.trace_to_jaxpr_dynamic(fun, [v.aval for v in jaxpr.invars]) res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
return jaxpr, res for v, w in zip(jaxpr.constvars, which)]
in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto): check_rep, rewrite, auto):
@ -1429,66 +1444,89 @@ def _partial_eval_jaxpr_custom_rule(
with core.extend_axis_env_nd(mesh.shape.items()): with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged) num_out_primals = len(jaxpr_known.outvars) - num_res
in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [idx_map.get(id(v)) for v in res_vars]
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
ins_known, _ = partition_list(unks_in, eqn.invars) ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars) out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, ins_staged = partition_list(inst_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars) _, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym([jaxpr_known, jaxpr_staged]) newvar = core.gensym([jaxpr_known, jaxpr_staged])
params_known, params_staged = _pe_custom_params( params_known, params_staged = _pe_custom_params(
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
residuals = [newvar(_unshard_aval(mesh, {0: (*mesh.axis_names,)}, var.aval)) residuals = [newvar(_unshard_aval(mesh, {0: (*mesh.axis_names,)}, var.aval))
for var in jaxpr_staged.invars[:num_res]] for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects, eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info) eqn.source_info)
eqn_staged = pe.new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, residuals_ = iter(residuals)
full_res = [ins_known[f1] if f1 is not None else
out_binders_known[f2] if f2 is not None else
next(residuals_) for f1, f2 in zip(in_fwd, out_fwd)]
sentinel = object()
assert next(residuals_, sentinel) is sentinel
eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
eqn.primitive, params_staged, eqn.primitive, params_staged,
jaxpr_staged.effects, eqn.source_info) jaxpr_staged.effects, eqn.source_info)
assert len(eqn_staged.invars) == len(jaxpr_staged.invars) assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in) new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst] if type(x) is core.Var and not inst]
new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}]
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
_partial_eval_jaxpr_custom_rule _partial_eval_jaxpr_custom_rule
def _add_reshapes(which, jaxpr_known, jaxpr_staged):
  """Add a singleton leading axis to residuals that are scalars.

  Args:
    which: bools, one per residual of ``jaxpr_staged``; True for residuals
      that are computed by ``jaxpr_known`` (rather than forwarded — see the
      caller's ``in_fwd``/``out_fwd`` handling).
    jaxpr_known: jaxpr whose trailing outputs are the residuals.
    jaxpr_staged: jaxpr whose leading inputs are the residuals.

  Returns:
    A ``(jaxpr_known, jaxpr_staged)`` pair where every scalar residual is
    emitted by the known jaxpr with shape ``(1,)`` and stripped back to a
    scalar on entry to the staged jaxpr.
  """
  # Only scalar residuals produced by jaxpr_known need the reshape.
  which_ = [w and not v.aval.shape
            for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
  if not any(which_): return jaxpr_known, jaxpr_staged
  assert not jaxpr_known.constvars and not jaxpr_staged.constvars

  @lu.wrap_init
  def known(*args):
    out = core.eval_jaxpr(jaxpr_known, (), *args)
    out_known, res = split_list(out, [len(out) - sum(which)])
    # Reshape only the residuals that are actually scalars.
    res = [_add_singleton(x) if not x.shape else x for x in res]
    return [*out_known, *res]
  avals_in = [v.aval for v in jaxpr_known.invars]
  jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic(known, avals_in)

  @lu.wrap_init
  def staged(*args):
    res_, ins = split_list(args, [len(which)])
    res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
    return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
  # Residual avals as seen by the staged jaxpr: scalars gain a leading unit
  # axis (via unmapped_aval), everything else is passed through unchanged.
  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
  avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
  jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
  return jaxpr_known, jaxpr_staged
def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
num_res, params_known, params_staged): in_fwd, out_fwd, which, params_known, params_staged):
# prune inputs to jaxpr_known according to unks_in # prune inputs to jaxpr_known according to unks_in
mesh = params_known['mesh'] mesh = params_known['mesh']
in_names_known, _ = partition_list(unks_in, params_known['in_names']) in_names_known, _ = partition_list(unks_in, params_known['in_names'])
_, out_names_known = partition_list(kept_outs_known, params_known['out_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * num_res out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * sum(which)
new_params_known = dict(params_known, in_names=tuple(in_names_known), new_params_known = dict(params_known, in_names=tuple(in_names_known),
out_names=tuple(out_names_known)) out_names=tuple(out_names_known))
# added num_res new inputs to jaxpr_staged, pruning according to inst_in # added num_res new inputs to jaxpr_staged, pruning according to inst_in
_, in_names_staged = partition_list(inst_in, params_staged['in_names']) _, in_names_staged = partition_list(inst_in, params_staged['in_names'])
in_names_staged = [{0: (*mesh.axis_names,)}] * num_res + in_names_staged res_names = [in_names_known[f1] if f1 is not None else
out_names_known[f2] if f2 is not None else
{0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
in_names_staged = res_names + in_names_staged
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
out_names=tuple(out_names_staged), check_rep=False) out_names=tuple(out_names_staged), check_rep=False)

View File

@@ -27,6 +27,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
import jax import jax
import jax.ad_checkpoint
from jax import lax from jax import lax
from jax.sharding import Mesh from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
@@ -1135,6 +1136,7 @@ class ShardMapTest(jtu.JaxTestCase):
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2) jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
def test_axis_env_extension_regression(self): def test_axis_env_extension_regression(self):
def foo(x): def foo(x):
i = jax.lax.axis_index('x') i = jax.lax.axis_index('x')
@@ -1147,6 +1149,51 @@ class ShardMapTest(jtu.JaxTestCase):
jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash
@parameterized.parameters(it.product([True, False], repeat=2))
def test_res_forwarding_optimization(self, jit, remat):
  # The primal output of f is its only residual (exp(x) is saved for the
  # backward pass), so with the forwarding optimization the shard_map
  # equation should not emit it a second time as a separate residual output.
  mesh = jtu.create_global_mesh((4,), ('i',))
  @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
  def f(x):
    return jax.lax.exp(x)
  if jit:
    f = jax.jit(f)
  if remat:
    policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
    f = jax.remat(f, policy=policy)
  g = lambda x: f(x).sum()

  x = jnp.arange(16.)
  jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
  jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
  e1, _, e2 = jaxpr.eqns
  self.assertLen(e1.outvars, 1)  # only primal output
  self.assertLen(e2.invars, 2)   # res and cotangent inputs
  # The forwarded residual is literally the same Var as the primal output.
  self.assertEqual(sum([e1.outvars[0] is v for v in e2.invars]), 1)
@parameterized.parameters(it.product([True, False], repeat=2))
def test_res_forwarding_optimization_complex(self, jit, remat):
  # like the above test, but a different function `f`: one residual
  # (exp(x)) is forwarded from a primal output, while the other
  # (exp(x.sum())) is a genuinely new residual output.
  mesh = jtu.create_global_mesh((4,), ('i',))
  @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
  def f(x):
    return jax.lax.exp(x.sum()) + x, jax.lax.exp(x)
  if jit:
    f = jax.jit(f)
  if remat:
    policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
    f = jax.remat(f, policy=policy)
  g = lambda x: sum(f(x)).sum()

  x = jnp.arange(16.)
  jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
  jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
  e1, _, e2 = jaxpr.eqns
  self.assertLen(e1.outvars, 2)  # one primal and one res output
  self.assertLen(e2.invars, 4)   # two res and two cotangent inputs
  # Exactly one of the backward inputs is forwarded from the primal output.
  self.assertEqual(sum([e1.outvars[-1] is v for v in e2.invars]), 1)
class FunSpec(NamedTuple): class FunSpec(NamedTuple):
name: str name: str