mirror of https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00

output res forwarding optimization for shard_map and jit

This commit is contained in:
parent 6aff74e7ff
commit 0944010186
jax/_src/interpreters/partial_eval.py
@@ -834,6 +834,37 @@ def trace_to_subjaxpr_nounits_fwd(
   del out_tracers
   yield jaxpr, (fwds, out_pvals, pruned_consts, env)
 
+# The below variant implements two optimizations:
+#  1. residuals that are also primal inputs are indicated in aux data rather
+#     than passed as outputs;
+#  2. residuals that are also primal outputs are indicated in aux data rather
+#     than passed as redundant outputs.
+@lu.transformation
+def trace_to_subjaxpr_nounits_fwd2(
+    main: core.MainTrace,
+    instantiate: bool | Sequence[bool],
+    in_pvals: Sequence[PartialVal]):
+  assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
+  out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits(
+      main, instantiate, in_pvals)
+  out_pvals = [t.pval for t in out_tracers]
+
+  # Which consts (aka residuals) are just forwarded inputs? Check obj id.
+  in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
+  id_map = {id(c): i for i, c in enumerate(in_consts)}
+  input_fwds: list[int | None] = [id_map.get(id(c)) for c in consts]
+
+  # Which consts (aka residuals) are already primal outputs? Check obj id.
+  out_consts = [pval.get_known() for pval in out_pvals if pval.is_known()]
+  id_map = {id(c): i for i, c in enumerate(out_consts)}
+  output_fwds: list[int | None] = [id_map.get(id(c)) for c in consts]
+
+  pruned_consts = [c for c, f1, f2 in zip(consts, input_fwds, output_fwds)
+                   if f1 is None and f2 is None]
+
+  del out_tracers
+  yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)
+
 
 FreeVar = namedtuple('FreeVar', ['val'])
 ConstVar = namedtuple('ConstVar', ['val'])
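The new `trace_to_subjaxpr_nounits_fwd2` matches residuals against the known inputs and known primal outputs by object identity, so a residual that is literally the same value need not be emitted as an extra output. A minimal standalone sketch of that bookkeeping (plain Python with hypothetical names, not the tracing machinery itself):

```python
# Sketch: detect residuals that are forwarded inputs or primal outputs, by id().
def find_forwards(consts, in_consts, out_consts):
  in_ids = {id(c): i for i, c in enumerate(in_consts)}
  out_ids = {id(c): i for i, c in enumerate(out_consts)}
  input_fwds = [in_ids.get(id(c)) for c in consts]
  output_fwds = [out_ids.get(id(c)) for c in consts]
  pruned = [c for c, f1, f2 in zip(consts, input_fwds, output_fwds)
            if f1 is None and f2 is None]
  return input_fwds, output_fwds, pruned

x, y, z = object(), object(), object()
in_fwds, out_fwds, pruned = find_forwards([x, y, z], in_consts=[x], out_consts=[z])
assert in_fwds == [0, None, None]   # residual 0 is just input 0
assert out_fwds == [None, None, 0]  # residual 2 is just primal output 0
assert pruned == [y]                # only y still needs to be an extra output
```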
@@ -1359,48 +1390,72 @@ def call_partial_eval_custom_rule(
               if type(x) is Var and not inst]
   return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
 
+# TODO(mattjj): unify with ParamsUpdater (this one takes an extra int)
+ParamsUpdater2 = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
+                           Sequence[bool], int, int, dict, dict],
+                          tuple[dict, dict]]
+
 def closed_call_partial_eval_custom_rule(
-    jaxpr_param_name: str, params_updater: ParamsUpdater,
+    jaxpr_param_name: str, params_updater: ParamsUpdater2,
     saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
     eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
   ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
   # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
   closed_jaxpr = eqn.params[jaxpr_param_name]
-  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \
+  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_val, num_res_ref = \
       partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in,
                                   False, False, saveable)
-  num_res = num_res_ref + num_res_out
+  num_res = num_res_ref + num_res_val
+
+  # Compute which residual value outputs are also *undropped* primal outputs.
+  num_out_primals = len(jaxpr_known_.outvars) - num_res_val
+  out_vars, res_vars = split_list(jaxpr_known_.outvars, [num_out_primals])
+  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
+  idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known))
+             if type(b) is not DropVar}
+  out_fwd = [idx_map.get(id(v)) for v in res_vars]
+
+  # Prune jaxpr_known_ outputs by removing forwards.
+  jaxpr_known_ = prune_jaxpr_outputs(
+      jaxpr_known_, [True] * num_out_primals + [f is None for f in out_fwd])
+
   # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
   jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
   jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
 
   ins_known, _ = partition_list(unks_in, eqn.invars)
-  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
   _, ins_staged = partition_list(inst_in, eqn.invars)
   _, out_binders_staged = partition_list(inst_out, eqn.outvars)
   newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr])
   params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
   params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
   params_known, params_staged = params_updater(
-      unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
-      params_staged)
-  residuals, ref_residuals = split_list(
-      [newvar(res_aval(params_known, v)) for v
-       in jaxpr_staged.in_avals[:num_res]], [num_res_out])
-  eqn_known = new_jaxpr_eqn([*ins_known, *ref_residuals],
-                            [*out_binders_known, *residuals],
+      unks_in, inst_in, map(op.not_, unks_out), inst_out,
+      sum(f is None for f in out_fwd), num_res, params_known, params_staged)
+  res_val_binders, res_ref_binders = split_list(
+      [newvar(res_aval(params_known, v))
+       for v in jaxpr_staged.in_avals[:num_res]], [num_res_val])
+  res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None]
+  res_val_binders_ = iter(res_val_binders)
+  res_val_vars = [out_binders_known[f] if f is not None
+                  else next(res_val_binders_) for f in out_fwd]
+  sentinel = object()
+  assert next(res_val_binders_, sentinel) is sentinel
+  eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders],
+                            [*out_binders_known, *res_val_binders],
                             eqn.primitive, params_known, jaxpr_known.effects,
                             eqn.source_info)
-  eqn_staged = new_jaxpr_eqn([*residuals, *ref_residuals, *ins_staged],
+  eqn_staged = new_jaxpr_eqn([*res_val_vars, *res_ref_binders, *ins_staged],
                              out_binders_staged,
                              eqn.primitive, params_staged, jaxpr_staged.effects,
                              eqn.source_info)
   assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
-  assert len(ins_known) + len(ref_residuals) == len(jaxpr_known.jaxpr.invars)
-  assert len(ins_staged) + len(ref_residuals) + len(residuals) == len(jaxpr_staged.jaxpr.invars)
-  assert len(out_binders_known) + len(residuals) == len(jaxpr_known.jaxpr.outvars)
+  assert len(ins_known) + len(res_ref_binders) == len(jaxpr_known.jaxpr.invars)
+  assert len(ins_staged) + len(res_ref_binders) + len(res_val_vars) == len(jaxpr_staged.jaxpr.invars)
+  assert len(out_binders_known) + len(res_val_binders) == len(jaxpr_known.jaxpr.outvars)
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is Var and not inst]
-  new_vars = [*new_inst, *residuals, *ref_residuals]
+  new_vars = [*new_inst, *res_val_vars, *res_ref_binders]
   return eqn_known, eqn_staged, unks_out, inst_out, new_vars
 
 partial_eval_jaxpr_custom_rules[core.call_p] = \
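The rule rebuilds `res_val_vars` with an iterator plus a sentinel: forwarded positions pull from the existing output binders, and the sentinel assert proves every freshly created binder was consumed exactly once. A self-contained sketch of the pattern (hypothetical `splice` helper):

```python
# Sketch: splice forwarded binders into the residual list, consuming the
# freshly made binders in order and checking none are left over.
def splice(fwd, pool, fresh):
  fresh_ = iter(fresh)
  full = [pool[f] if f is not None else next(fresh_) for f in fwd]
  sentinel = object()
  assert next(fresh_, sentinel) is sentinel  # every fresh binder used once
  return full

assert splice([None, 1, None], pool=['a', 'b'], fresh=['x', 'y']) == ['x', 'b', 'y']
```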
@@ -1408,7 +1463,7 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
           lambda _, __, ___, ____, _____, x, y: (x, y))
 partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
     partial(closed_call_partial_eval_custom_rule, 'call_jaxpr',
-            lambda _, __, ___, ____, _____, x, y: (x, y))
+            lambda _, __, ___, ____, _____, ______, x, y: (x, y))
 
 
 def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]:
@@ -1427,6 +1482,33 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]:
           for v in jaxpr.outvars]
 
 
+def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
+  return _prune_jaxpr_outputs_cached(jaxpr, tuple(used_outputs))
+
+def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
+  outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
+  dbg = jaxpr.debug_info and core.JaxprDebugInfo(
+      jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
+      jaxpr.debug_info.arg_names,
+      tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
+  new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
+  config.enable_checks.value and core.check_jaxpr(new_jaxpr)
+  return new_jaxpr
+_prune_jaxpr_outputs_cached = weakref_lru_cache(_prune_jaxpr_outputs)
+
+def prune_closed_jaxpr_outputs(
+    jaxpr: ClosedJaxpr, used_outputs: Sequence[bool]
+) -> ClosedJaxpr:
+  return _prune_closed_jaxpr_outputs(jaxpr, tuple(used_outputs))
+
+@weakref_lru_cache
+def _prune_closed_jaxpr_outputs(
+    jaxpr: ClosedJaxpr, used_outputs: tuple[bool, ...]
+) -> ClosedJaxpr:
+  return ClosedJaxpr(_prune_jaxpr_outputs(jaxpr.jaxpr, used_outputs),
+                     jaxpr.consts)
+
+
 def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
               instantiate: bool | Sequence[bool] = False,
               ) -> tuple[Jaxpr, list[bool]]:
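The pruning helpers follow a common caching split: a thin public wrapper normalizes `used_outputs` into a hashable tuple, then dispatches to a cached worker. A sketch of that pattern (using `functools.lru_cache` for illustration; JAX's `weakref_lru_cache` additionally keys weakly on the jaxpr so the cache does not keep dead jaxprs alive):

```python
import functools

# Cached worker: only ever sees hashable tuples.
@functools.lru_cache(maxsize=None)
def _prune_outputs_cached(outputs: tuple, used: tuple):
  return tuple(o for o, u in zip(outputs, used) if u)

# Public wrapper: its only job is normalizing the arguments for the cache.
def prune_outputs(outputs, used):
  return _prune_outputs_cached(tuple(outputs), tuple(used))

assert prune_outputs(['p0', 'r0', 'r1'], [True, True, False]) == ('p0', 'r0')
```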
jax/_src/pjit.py
@@ -1514,9 +1514,9 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
 
 @weakref_lru_cache
 def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
-                     fwds_known: tuple[Optional[int]]) -> core.ClosedJaxpr:
+                     in_fwd: tuple[Optional[int]]) -> core.ClosedJaxpr:
   updated_jaxpr = known_jaxpr.jaxpr.replace(
-      outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known)
+      outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, in_fwd)
                if i is None])
   return known_jaxpr.replace(jaxpr=updated_jaxpr)
@@ -1533,59 +1533,55 @@ def _pjit_partial_eval(trace, *in_tracers,
   unknown_outs = tuple(unknown_outs)
   known_outs = tuple(not uk for uk in unknown_outs)
   num_residuals = len(res_avals)
+  res_shardings = (UNSPECIFIED,) * num_residuals
 
   def keep_where(l, should_keep):
-    return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep)
+    return tuple(x for x, keep in zip(l, should_keep) if keep)
+
+  # Compute which outputs are just forwarded inputs.
+  num_out_primals = len(known_jaxpr.out_avals) - num_residuals
+  in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
+
+  # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
+  in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
+  in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
+            zip(keep_where(out_shardings, known_outs), in_fwd_primal)
+            ] + in_fwd_res
+  del in_fwd_primal, in_fwd_res
+
+  # Compute which residuals are just primal outputs.
+  out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
+  idx_map = {id(v): i for i, v in enumerate(out_vars)}
+  out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
+
+  # Prune jaxpr outputs and out_shardings by removing forwards.
+  keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
+  known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
+  known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
+  known_out_shardings = keep_where(known_out_shardings, keep)
+  del keep, num_out_primals
 
-  residual_shardings = (UNSPECIFIED,) * num_residuals
   # Compute the known outputs
   known_params = dict(
-      jaxpr=known_jaxpr,
-      in_shardings=keep_where(in_shardings, known_ins),
-      out_shardings=(
-          keep_where(out_shardings, known_outs) + residual_shardings),
-      resource_env=resource_env,
+      jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
+      out_shardings=known_out_shardings, resource_env=resource_env,
       donated_invars=keep_where(donated_invars, known_ins),
-      name=name,
-      keep_unused=keep_unused,
-      inline=inline)
-
-  fwds_known = pe._jaxpr_forwarding(known_params['jaxpr'].jaxpr)
-
-  # Only forward the outvars where the out_sharding is UNSPECIFIED.
-  known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs)
-  fwds_known_user = [
-      fwd if is_unspecified(os) else None
-      for os, fwd in zip(known_user_out_shardings,
-                         fwds_known[:len(known_user_out_shardings)])]
-  fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):]
-  del fwds_known_user
-
-  # Remove forwarded outvars and out_shardings
-  known_params['jaxpr'] = _known_jaxpr_fwd(known_params['jaxpr'], tuple(fwds_known))
-  known_out_shardings = tuple(
-      s for s, i in zip(known_params['out_shardings'], fwds_known) if i is None)
-  known_params['out_shardings'] = known_out_shardings
-  del known_out_shardings
-
+      name=name, keep_unused=keep_unused, inline=inline)
   assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
 
   # Bind known things to pjit_p.
   known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
   all_known_outs = pjit_p.bind(*known_inputs, **known_params)
 
-  known_outs_iter = iter(all_known_outs)
-  all_known_outs = [next(known_outs_iter)
-                    if fwd_idx is None else known_inputs[fwd_idx]
-                    for fwd_idx in fwds_known]
-  assert next(known_outs_iter, None) is None
-  del known_outs_iter, known_inputs
+  all_known_outs_ = iter(all_known_outs)
+  all_known_outs = [known_inputs[f1] if f1 is not None else
+                    all_known_outs[f2] if f2 is not None else
+                    next(all_known_outs_) for f1, f2 in zip(in_fwd, out_fwd)]
+  sentinel = object()
+  assert next(all_known_outs_, sentinel) is sentinel
+  del all_known_outs_, known_inputs
 
-  if num_residuals:
-    known_out_vals, residual_vals = \
-        split_list(all_known_outs, [len(all_known_outs) - num_residuals])
-  else:
-    known_out_vals, residual_vals = all_known_outs, ()
+  known_out_vals, residual_vals = \
+      split_list(all_known_outs, [len(all_known_outs) - num_residuals])
   residual_tracers = [trace.new_instantiated_const(residual) for residual in residual_vals]
 
   # The convention of partial_eval_jaxpr_nounits is to place residual binders
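Putting the pjit pieces together: entries forwarded from inputs (`in_fwd`) or from primal outputs (`out_fwd`) are never returned by the bound known computation and get spliced back afterwards. A toy reconstruction under those assumptions (written as a loop rather than a comprehension so `out_fwd` can index the already rebuilt outputs; all names hypothetical):

```python
# Sketch: rebuild full outputs from the pruned ones plus two forward maps.
def splice_outputs(inputs, pruned_outs, in_fwd, out_fwd):
  pruned_ = iter(pruned_outs)
  full = []
  for f1, f2 in zip(in_fwd, out_fwd):
    if f1 is not None:       # output is just input f1
      full.append(inputs[f1])
    elif f2 is not None:     # output duplicates an earlier (rebuilt) output f2
      full.append(full[f2])
    else:                    # genuinely computed: take the next pruned output
      full.append(next(pruned_))
  sentinel = object()
  assert next(pruned_, sentinel) is sentinel  # all pruned outputs consumed
  return full

# output 0 forwards input 0; output 2 (a residual) duplicates primal output 1
assert splice_outputs(['a', 'b'], ['p1'], [0, None, None],
                      [None, None, 1]) == ['a', 'p1', 'p1']
```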
@@ -1597,7 +1593,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   # Prepare unknown tracers
   unknown_params = dict(
       jaxpr=unknown_jaxpr,
-      in_shardings=(keep_where(in_shardings, unknown_ins) + residual_shardings),
+      in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings),
       out_shardings=keep_where(out_shardings, unknown_outs),
       resource_env=resource_env,
       donated_invars=(keep_where(donated_invars, unknown_ins) +
@@ -1626,28 +1622,25 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
 def _pjit_partial_eval_custom_params_updater(
     unks_in: Sequence[bool], inst_in: Sequence[bool],
     kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
-    num_res: int, params_known: dict, params_staged: dict
+    num_res_out: int, num_res_in: int, params_known: dict, params_staged: dict
   ) -> tuple[dict, dict]:
   # prune inputs to jaxpr_known according to unks_in
   donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
   in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
-  if num_res == 0:
-    residual_shardings = []
-  else:
-    residual_shardings = [UNSPECIFIED] * num_res
   _, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
   new_params_known = dict(params_known,
                           in_shardings=tuple(in_shardings_known),
-                          out_shardings=(*out_shardings_known, *residual_shardings),
+                          out_shardings=(*out_shardings_known,
+                                         *[UNSPECIFIED] * num_res_out),
                           donated_invars=tuple(donated_invars_known))
   assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
   assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
 
   # added num_res new inputs to jaxpr_staged, and pruning according to inst_in
   _, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
-  donated_invars_staged = [False] * num_res + donated_invars_staged
+  donated_invars_staged = [False] * num_res_in + donated_invars_staged
   _, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
-  in_shardings_staged = [*residual_shardings, *in_shardings_staged]
+  in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged]
 
   _, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
jax/experimental/shard_map.py
@@ -1262,30 +1262,37 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
   in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
   unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
   in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
-  f = pe.trace_to_subjaxpr_nounits(f, trace.main, False)
+  f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.main, False)
   f = _promote_scalar_residuals(f)
   f_known, aux = pe.partial_eval_wrapper_nounits(
       f, (*in_knowns,), (*in_avals_sharded,))
 
   @as_hashable_function(closure=out_names_thunk)
   def known_out_names():
-    out_knowns, _, jaxpr, _ = aux()
+    in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux()
     _, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
-    assert not any(not v.aval.shape for v in jaxpr.constvars)
-    res_names = ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
-    return (*out_known_names, *res_names)
+    num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
+    return (*out_known_names, *({0: (*mesh.axis_names,)},) * num_res)
 
   known_params = dict(mesh=mesh, in_names=(*known_in_names,),
                       out_names_thunk=known_out_names, check_rep=check_rep,
                       rewrite=rewrite, auto=auto)
   out = shard_map_p.bind(f_known, *in_consts, **known_params)
-  out_knowns, out_avals_sharded, jaxpr, env = aux()
-  out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
-  with core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
+  in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux()
+  num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
+  out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
+  assert not jaxpr.constvars
   unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
-  unk_in_names = (({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
-                  + (*unk_in_names,))
+  known_out_names_ = known_out_names()
+  non_fwd_res_ = iter(non_fwd_res)
+  res = [in_consts[f1] if f1 is not None else out_consts[f2] if f2 is not None
+         else next(non_fwd_res_) for f1, f2 in zip(in_fwd, out_fwd)]
+  sentinel = object()
+  assert next(non_fwd_res_, sentinel) is sentinel
+  res_names = [known_in_names[f1] if f1 is not None else
+               known_out_names_[f2] if f2 is not None else
+               {0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+  unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
   const_tracers = map(trace.new_instantiated_const, res)
   env_tracers = map(trace.full_raise, env)
   unk_arg_tracers = [t for t in tracers if not t.is_known()]
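In shard_map the forward maps also choose each residual's mesh-axis names: a residual forwarded from an input or a primal output inherits that value's names, and only genuinely fresh residuals get the default all-axes names. A sketch with plain dicts standing in for shard_map's `{axis_position: axis_names}` values (hypothetical helper):

```python
# Sketch: residual names follow the forwarded source's names.
def residual_names(in_fwd, out_fwd, known_in_names, known_out_names, default):
  return [known_in_names[f1] if f1 is not None else
          known_out_names[f2] if f2 is not None else
          default for f1, f2 in zip(in_fwd, out_fwd)]

names = residual_names([0, None], [None, 1],
                       known_in_names=[{0: ('i',)}],
                       known_out_names=[{}, {0: ('i',)}],
                       default={0: ('i', 'j')})
assert names == [{0: ('i',)}, {0: ('i',)}]
```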
@@ -1307,7 +1314,11 @@ def _shard_map_partial_eval_post_process(
   del check_rep
   unk_tracers = [t for t in tracers if not t.is_known()]
   jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
-  jaxpr, res = _promote_scalar_residuals_jaxpr(jaxpr, res)
+  # TODO(mattjj): output forwarding optimization
+  which = [not v.aval.shape for v in jaxpr.constvars]
+  res = [jax.lax.broadcast(x, (1,)) if not v.aval.shape else x
+         for x, v in zip(res, jaxpr.constvars)]
+  jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
 
   out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
   out = [*consts, *res]
@@ -1317,11 +1328,11 @@ def _shard_map_partial_eval_post_process(
 
   def todo(out):
     trace = main.with_cur_sublevel()
-    out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
-    const_tracers = map(trace.new_instantiated_const, res)
+    out_consts, res_ = split_list(out, [len(out) - len(res)])
+    const_tracers = map(trace.new_instantiated_const, res_)
     env_tracers = map(trace.full_raise, env)
 
-    staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
+    staged_in_names = ({0: (*mesh.axis_names,)},) * len(res_) + ({},) * len(env)
     staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
                          out_names=(*out_names_unknown,), check_rep=False,
                          rewrite=rewrite, auto=auto)
@@ -1339,7 +1350,7 @@ def _shard_map_partial_eval_post_process(
   def out_names_transform(out_names):
     nonlocal out_names_unknown
     out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
-    return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
+    return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(res)
   out_names_unknown: list | None = None
 
   return out, (todo, out_names_transform)
@@ -1347,21 +1358,25 @@ pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
 
 @lu.transformation
 def _promote_scalar_residuals(*args, **kwargs):
-  jaxpr, (out_pvals, out_consts, env) = yield args, kwargs
-  jaxpr, out_consts = _promote_scalar_residuals_jaxpr(jaxpr, out_consts)
-  yield jaxpr, (out_pvals, out_consts, env)
+  jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs
+  which = [f1 is None and f2 is None and not v.aval.shape
+           for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
+  jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
+  out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
+                for x in out_consts]
+  yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
 
-def _promote_scalar_residuals_jaxpr(jaxpr, res):
-  which = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
-           for v in jaxpr.constvars]
-  res_ = [jax.lax.broadcast(x, (1,)) if s else x for x, s in zip(res, which)]
+def _promote_scalar_residuals_jaxpr(jaxpr, which):
   @lu.wrap_init
-  def fun(*args):
-    res = [_rem_singleton(x) if s else x for x, s in zip(res_, which)]
+  def fun(*res_and_args):
+    res, args = split_list(res_and_args, [len(jaxpr.constvars)])
+    res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
     return core.eval_jaxpr(jaxpr, res, *args)
-  jaxpr, _, res = pe.trace_to_jaxpr_dynamic(fun, [v.aval for v in jaxpr.invars])
-  return jaxpr, res
+  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+               for v, w in zip(jaxpr.constvars, which)]
+  in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
+  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(fun, in_avals)
+  return jaxpr
 
 def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
                          check_rep, rewrite, auto):
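`_promote_scalar_residuals` exists because a scalar residual has no axis for shard_map to shard over, so a singleton leading axis is added on the known side and stripped again inside the staged body. The shape bookkeeping in runnable JAX (a sketch; `lax.squeeze` stands in for the internal `_rem_singleton`):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.float32(3.14)                  # scalar residual, shape ()
promoted = lax.broadcast(x, (1,))      # shape (1,): now has an axis to map over
assert promoted.shape == (1,)
restored = lax.squeeze(promoted, [0])  # inside the staged body: back to ()
assert restored.shape == ()
```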
@@ -1429,66 +1444,89 @@ def _partial_eval_jaxpr_custom_rule(
   with core.extend_axis_env_nd(mesh.shape.items()):
     jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
         pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
-    jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged)
+  num_out_primals = len(jaxpr_known.outvars) - num_res
+  in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
+  out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
+  idx_map = {id(v): i for i, v in enumerate(out_vars)}
+  out_fwd = [idx_map.get(id(v)) for v in res_vars]
+  which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
+  with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
+    jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
+    jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
   ins_known, _ = partition_list(unks_in, eqn.invars)
   out_binders_known, _ = partition_list(unks_out, eqn.outvars)
   _, ins_staged = partition_list(inst_in, eqn.invars)
   _, out_binders_staged = partition_list(inst_out, eqn.outvars)
   newvar = core.gensym([jaxpr_known, jaxpr_staged])
   params_known, params_staged = _pe_custom_params(
-      unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res,
+      unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
       dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
   residuals = [newvar(_unshard_aval(mesh, {0: (*mesh.axis_names,)}, var.aval))
-               for var in jaxpr_staged.invars[:num_res]]
+               for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
   eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                                eqn.primitive, params_known, jaxpr_known.effects,
                                eqn.source_info)
-  eqn_staged = pe.new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
+  residuals_ = iter(residuals)
+  full_res = [ins_known[f1] if f1 is not None else
+              out_binders_known[f2] if f2 is not None else
+              next(residuals_) for f1, f2 in zip(in_fwd, out_fwd)]
+  sentinel = object()
+  assert next(residuals_, sentinel) is sentinel
+  eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
                                 eqn.primitive, params_staged,
                                 jaxpr_staged.effects, eqn.source_info)
   assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is core.Var and not inst]
+  new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}]
   return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
 pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
     _partial_eval_jaxpr_custom_rule
 
-def _add_reshapes(num_res, jaxpr_known, jaxpr_staged):
-  if not num_res: return jaxpr_known, jaxpr_staged
+def _add_reshapes(which, jaxpr_known, jaxpr_staged):
+  # add singleton axes to residuals which are from jaxpr_known and are scalars
+  which_ = [w and not v.aval.shape
+            for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
+  if not any(which_): return jaxpr_known, jaxpr_staged
   assert not jaxpr_known.constvars and not jaxpr_staged.constvars
 
   @lu.wrap_init
   def known(*args):
     out = core.eval_jaxpr(jaxpr_known, (), *args)
-    out_known, res = split_list(out, [len(out) - num_res])
-    return [*out_known, *map(_add_singleton, res)]
+    out_known, res = split_list(out, [len(out) - sum(which)])
+    res = [_add_singleton(x) if not x.shape else x for x in res]
+    return [*out_known, *res]
   avals_in = [v.aval for v in jaxpr_known.invars]
   jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic(known, avals_in)
 
   @lu.wrap_init
   def staged(*args):
-    res_, ins = split_list(args, [num_res])
-    res = map(_rem_singleton, res_)
+    res_, ins = split_list(args, [len(which)])
+    res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
     return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
-  res_avals = [v.aval for v in jaxpr_known.outvars[-num_res:]]
-  avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[num_res:]]]
+  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
+  avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
   jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
 
   return jaxpr_known, jaxpr_staged
 
 def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
-                      num_res, params_known, params_staged):
+                      in_fwd, out_fwd, which, params_known, params_staged):
   # prune inputs to jaxpr_known according to unks_in
   mesh = params_known['mesh']
   in_names_known, _ = partition_list(unks_in, params_known['in_names'])
   _, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
-  out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * num_res
+  out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * sum(which)
   new_params_known = dict(params_known, in_names=tuple(in_names_known),
                           out_names=tuple(out_names_known))
 
   # added num_res new inputs to jaxpr_staged, pruning according to inst_in
   _, in_names_staged = partition_list(inst_in, params_staged['in_names'])
-  in_names_staged = [{0: (*mesh.axis_names,)}] * num_res + in_names_staged
+  res_names = [in_names_known[f1] if f1 is not None else
+               out_names_known[f2] if f2 is not None else
+               {0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+  in_names_staged = res_names + in_names_staged
   _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
   new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
                            out_names=tuple(out_names_staged), check_rep=False)
tests/shard_map_test.py
@@ -27,6 +27,7 @@ from absl.testing import parameterized
 import numpy as np
 
 import jax
+import jax.ad_checkpoint
 from jax import lax
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
@@ -1135,6 +1136,7 @@ class ShardMapTest(jtu.JaxTestCase):
 
     jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
 
+
   def test_axis_env_extension_regression(self):
     def foo(x):
       i = jax.lax.axis_index('x')
@@ -1147,6 +1149,51 @@ class ShardMapTest(jtu.JaxTestCase):
 
     jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.))  # doesn't crash
 
+  @parameterized.parameters(it.product([True, False], repeat=2))
+  def test_res_forwarding_optimization(self, jit, remat):
+    mesh = jtu.create_global_mesh((4,), ('i',))
+
+    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
+    def f(x):
+      return jax.lax.exp(x)
+    if jit:
+      f = jax.jit(f)
+    if remat:
+      policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
+      f = jax.remat(f, policy=policy)
+    g = lambda x: f(x).sum()
+
+    x = jnp.arange(16.)
+    jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
+    jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
+    e1, _, e2 = jaxpr.eqns
+    self.assertLen(e1.outvars, 1)  # only primal output
+    self.assertLen(e2.invars, 2)   # res and cotangent inputs
+    self.assertEqual(sum([e1.outvars[0] is v for v in e2.invars]), 1)
+
+  @parameterized.parameters(it.product([True, False], repeat=2))
+  def test_res_forwarding_optimization_complex(self, jit, remat):
+    # like the above test, but a different function `f`
+    mesh = jtu.create_global_mesh((4,), ('i',))
+
+    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
+    def f(x):
+      return jax.lax.exp(x.sum()) + x, jax.lax.exp(x)
+    if jit:
+      f = jax.jit(f)
+    if remat:
+      policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
+      f = jax.remat(f, policy=policy)
+    g = lambda x: sum(f(x)).sum()
+
+    x = jnp.arange(16.)
+    jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
+    jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
+    e1, _, e2 = jaxpr.eqns
+    self.assertLen(e1.outvars, 2)  # one primal and one res output
+    self.assertLen(e2.invars, 4)   # two res and two cotangent inputs
+    self.assertEqual(sum([e1.outvars[-1] is v for v in e2.invars]), 1)
+
+
 class FunSpec(NamedTuple):
   name: str