mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
output res forwarding optimization for shard_map and jit
This commit is contained in:
parent 6aff74e7ff
commit 0944010186
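The core idea of the commit: a residual produced while partially evaluating a shard_map or jit computation is often the very same array object as one of the known primal inputs or known primal outputs. Instead of emitting such a residual as an extra output of the known computation, the new tracing variant records a forwarding index in aux data and only materializes residuals that have no forwarding source. A minimal sketch of that bookkeeping with plain Python lists and a hypothetical helper name (not part of the commit itself):

# Illustrative sketch only: detect residuals that are forwarded inputs or
# outputs by object identity, mirroring the id()-based maps in the diff below.
def compute_forwards(consts, in_consts, out_consts):
  in_ids = {id(c): i for i, c in enumerate(in_consts)}
  out_ids = {id(c): i for i, c in enumerate(out_consts)}
  input_fwds = [in_ids.get(id(c)) for c in consts]    # index into inputs, or None
  output_fwds = [out_ids.get(id(c)) for c in consts]  # index into outputs, or None
  # Only residuals with no forwarding source need to be emitted as outputs.
  pruned_consts = [c for c, f1, f2 in zip(consts, input_fwds, output_fwds)
                   if f1 is None and f2 is None]
  return input_fwds, output_fwds, pruned_consts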
@@ -834,6 +834,37 @@ def trace_to_subjaxpr_nounits_fwd(
   del out_tracers
   yield jaxpr, (fwds, out_pvals, pruned_consts, env)
+
+# The below variant implements two optimizations:
+#  1. residuals that are also primal inputs are indicated in aux data rather
+#     than passed as outputs;
+#  2. residuals that are also primal outputs are indicated in aux data rather
+#     than passed as redundant outputs.
+@lu.transformation
+def trace_to_subjaxpr_nounits_fwd2(
+    main: core.MainTrace,
+    instantiate: bool | Sequence[bool],
+    in_pvals: Sequence[PartialVal]):
+  assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
+  out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits(
+      main, instantiate, in_pvals)
+  out_pvals = [t.pval for t in out_tracers]
+
+  # Which consts (aka residuals) are just forwarded inputs? Check obj id.
+  in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
+  id_map = {id(c): i for i, c in enumerate(in_consts)}
+  input_fwds: list[int | None] = [id_map.get(id(c)) for c in consts]
+
+  # Which consts (aka residuals) are already primal outputs? Check obj id.
+  out_consts = [pval.get_known() for pval in out_pvals if pval.is_known()]
+  id_map = {id(c): i for i, c in enumerate(out_consts)}
+  output_fwds: list[int | None] = [id_map.get(id(c)) for c in consts]
+
+  pruned_consts = [c for c, f1, f2 in zip(consts, input_fwds, output_fwds)
+                   if f1 is None and f2 is None]
+
+  del out_tracers
+  yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)
 
 
 FreeVar = namedtuple('FreeVar', ['val'])
 ConstVar = namedtuple('ConstVar', ['val'])
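Consumers of this aux data (the pjit, closed-call, and shard_map rules in the hunks below) rebuild the full residual list by pulling forwarded entries from the known inputs or known outputs and taking the rest, in order, from the pruned outputs. A hedged sketch of that reconstruction pattern with plain lists; only the in_fwd/out_fwd convention comes from the diff, the helper itself is illustrative:

# Illustrative sketch only: rebuild residuals from forwarding indices.
def rebuild_residuals(in_fwd, out_fwd, known_inputs, known_outputs, non_fwd_res):
  non_fwd_res_ = iter(non_fwd_res)
  res = [known_inputs[f1] if f1 is not None else
         known_outputs[f2] if f2 is not None else
         next(non_fwd_res_) for f1, f2 in zip(in_fwd, out_fwd)]
  sentinel = object()
  assert next(non_fwd_res_, sentinel) is sentinel  # every pruned residual consumed
  return res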
@@ -1359,48 +1390,72 @@ def call_partial_eval_custom_rule(
     if type(x) is Var and not inst]
   return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
 
+# TODO(mattjj): unify with ParamsUpdater (this one takes an extra int)
+ParamsUpdater2 = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
+                           Sequence[bool], int, int, dict, dict],
+                          tuple[dict, dict]]
+
 def closed_call_partial_eval_custom_rule(
-    jaxpr_param_name: str, params_updater: ParamsUpdater,
+    jaxpr_param_name: str, params_updater: ParamsUpdater2,
     saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
     eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
   ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
   # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
   closed_jaxpr = eqn.params[jaxpr_param_name]
-  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \
+  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_val, num_res_ref = \
       partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in,
                                   False, False, saveable)
-  num_res = num_res_ref + num_res_out
+  num_res = num_res_ref + num_res_val
+
+  # Compute which residual value outputs are also *undropped* primal outputs.
+  num_out_primals = len(jaxpr_known_.outvars) - num_res_val
+  out_vars, res_vars = split_list(jaxpr_known_.outvars, [num_out_primals])
+  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
+  idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known))
+             if type(b) is not DropVar}
+  out_fwd = [idx_map.get(id(v)) for v in res_vars]
+
+  # Prune jaxpr_known_ outputs by removing forwards.
+  jaxpr_known_ = prune_jaxpr_outputs(
+      jaxpr_known_, [True] * num_out_primals + [f is None for f in out_fwd])
+
   # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
   jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
   jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
+
   ins_known, _ = partition_list(unks_in, eqn.invars)
-  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
   _, ins_staged = partition_list(inst_in, eqn.invars)
   _, out_binders_staged = partition_list(inst_out, eqn.outvars)
   newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr])
   params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
   params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
   params_known, params_staged = params_updater(
-      unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
-      params_staged)
-  residuals, ref_residuals = split_list(
-      [newvar(res_aval(params_known, v)) for v
-       in jaxpr_staged.in_avals[:num_res]], [num_res_out])
-  eqn_known = new_jaxpr_eqn([*ins_known, *ref_residuals],
-                            [*out_binders_known, *residuals],
+      unks_in, inst_in, map(op.not_, unks_out), inst_out,
+      sum(f is None for f in out_fwd), num_res, params_known, params_staged)
+  res_val_binders, res_ref_binders = split_list(
+      [newvar(res_aval(params_known, v))
+       for v in jaxpr_staged.in_avals[:num_res]], [num_res_val])
+  res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None]
+  res_val_binders_ = iter(res_val_binders)
+  res_val_vars = [out_binders_known[f] if f is not None
+                  else next(res_val_binders_) for f in out_fwd]
+  sentinel = object()
+  assert next(res_val_binders_, sentinel) is sentinel
+  eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders],
+                            [*out_binders_known, *res_val_binders],
                             eqn.primitive, params_known, jaxpr_known.effects,
                             eqn.source_info)
-  eqn_staged = new_jaxpr_eqn([*residuals, *ref_residuals, *ins_staged],
+  eqn_staged = new_jaxpr_eqn([*res_val_vars, *res_ref_binders, *ins_staged],
                              out_binders_staged,
                              eqn.primitive, params_staged, jaxpr_staged.effects,
                              eqn.source_info)
   assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
-  assert len(ins_known) + len(ref_residuals) == len(jaxpr_known.jaxpr.invars)
-  assert len(ins_staged) + len(ref_residuals) + len(residuals) == len(jaxpr_staged.jaxpr.invars)
-  assert len(out_binders_known) + len(residuals) == len(jaxpr_known.jaxpr.outvars)
+  assert len(ins_known) + len(res_ref_binders) == len(jaxpr_known.jaxpr.invars)
+  assert len(ins_staged) + len(res_ref_binders) + len(res_val_vars) == len(jaxpr_staged.jaxpr.invars)
+  assert len(out_binders_known) + len(res_val_binders) == len(jaxpr_known.jaxpr.outvars)
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is Var and not inst]
-  new_vars = [*new_inst, *residuals, *ref_residuals]
+  new_vars = [*new_inst, *res_val_vars, *res_ref_binders]
   return eqn_known, eqn_staged, unks_out, inst_out, new_vars
 
 partial_eval_jaxpr_custom_rules[core.call_p] = \
@@ -1408,7 +1463,7 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
             lambda _, __, ___, ____, _____, x, y: (x, y))
 partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
     partial(closed_call_partial_eval_custom_rule, 'call_jaxpr',
-            lambda _, __, ___, ____, _____, x, y: (x, y))
+            lambda _, __, ___, ____, _____, ______, x, y: (x, y))
 
 
 def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]:
@@ -1427,6 +1482,33 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]:
           for v in jaxpr.outvars]
 
 
+def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
+  return _prune_jaxpr_outputs_cached(jaxpr, tuple(used_outputs))
+
+def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
+  outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
+  dbg = jaxpr.debug_info and core.JaxprDebugInfo(
+      jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
+      jaxpr.debug_info.arg_names,
+      tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
+  new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
+  config.enable_checks.value and core.check_jaxpr(new_jaxpr)
+  return new_jaxpr
+_prune_jaxpr_outputs_cached = weakref_lru_cache(_prune_jaxpr_outputs)
+
+def prune_closed_jaxpr_outputs(
+    jaxpr: ClosedJaxpr, used_outputs: Sequence[bool]
+) -> ClosedJaxpr:
+  return _prune_closed_jaxpr_outputs(jaxpr, tuple(used_outputs))
+
+@weakref_lru_cache
+def _prune_closed_jaxpr_outputs(
+    jaxpr: ClosedJaxpr, used_outputs: tuple[bool, ...]
+) -> ClosedJaxpr:
+  return ClosedJaxpr(_prune_jaxpr_outputs(jaxpr.jaxpr, used_outputs),
+                     jaxpr.consts)
+
+
 def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
               instantiate: bool | Sequence[bool] = False,
               ) -> tuple[Jaxpr, list[bool]]:
@@ -1514,9 +1514,9 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
 
 @weakref_lru_cache
 def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
-                     fwds_known: tuple[Optional[int]]) -> core.ClosedJaxpr:
+                     in_fwd: tuple[Optional[int]]) -> core.ClosedJaxpr:
   updated_jaxpr = known_jaxpr.jaxpr.replace(
-      outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known)
+      outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, in_fwd)
                if i is None])
   return known_jaxpr.replace(jaxpr=updated_jaxpr)
 
@@ -1533,59 +1533,55 @@ def _pjit_partial_eval(trace, *in_tracers,
   unknown_outs = tuple(unknown_outs)
   known_outs = tuple(not uk for uk in unknown_outs)
   num_residuals = len(res_avals)
+  res_shardings = (UNSPECIFIED,) * num_residuals
 
   def keep_where(l, should_keep):
-    return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep)
+    return tuple(x for x, keep in zip(l, should_keep) if keep)
 
-  residual_shardings = (UNSPECIFIED,) * num_residuals
-  # Compute the known outputs
+  # Compute which outputs are just forwarded inputs.
+  num_out_primals = len(known_jaxpr.out_avals) - num_residuals
+  in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
+
+  # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
+  in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
+  in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
+            zip(keep_where(out_shardings, known_outs), in_fwd_primal)
+            ] + in_fwd_res
+  del in_fwd_primal, in_fwd_res
+
+  # Compute which residuals are just primal outputs.
+  out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
+  idx_map = {id(v): i for i, v in enumerate(out_vars)}
+  out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
+
+  # Prune jaxpr outputs and out_shardings by removing forwards.
+  keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
+  known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
+  known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
+  known_out_shardings = keep_where(known_out_shardings, keep)
+  del keep, num_out_primals
+
   known_params = dict(
-      jaxpr=known_jaxpr,
-      in_shardings=keep_where(in_shardings, known_ins),
-      out_shardings=(
-          keep_where(out_shardings, known_outs) + residual_shardings),
-      resource_env=resource_env,
+      jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
+      out_shardings=known_out_shardings, resource_env=resource_env,
       donated_invars=keep_where(donated_invars, known_ins),
-      name=name,
-      keep_unused=keep_unused,
-      inline=inline)
+      name=name, keep_unused=keep_unused, inline=inline)
 
-  fwds_known = pe._jaxpr_forwarding(known_params['jaxpr'].jaxpr)
-
-  # Only forward the outvars where the out_sharding is UNSPECIFIED.
-  known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs)
-  fwds_known_user = [
-      fwd if is_unspecified(os) else None
-      for os, fwd in zip(known_user_out_shardings,
-                         fwds_known[:len(known_user_out_shardings)])]
-  fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):]
-  del fwds_known_user
-
-  # Remove forwarded outvars and out_shardings
-  known_params['jaxpr'] = _known_jaxpr_fwd(known_params['jaxpr'], tuple(fwds_known))
-  known_out_shardings = tuple(
-      s for s, i in zip(known_params['out_shardings'], fwds_known) if i is None)
-  known_params['out_shardings'] = known_out_shardings
-  del known_out_shardings
-
   assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
 
   # Bind known things to pjit_p.
   known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
   all_known_outs = pjit_p.bind(*known_inputs, **known_params)
 
-  known_outs_iter = iter(all_known_outs)
-  all_known_outs = [next(known_outs_iter)
-                    if fwd_idx is None else known_inputs[fwd_idx]
-                    for fwd_idx in fwds_known]
-  assert next(known_outs_iter, None) is None
-  del known_outs_iter, known_inputs
+  all_known_outs_ = iter(all_known_outs)
+  all_known_outs = [known_inputs[f1] if f1 is not None else
+                    all_known_outs[f2] if f2 is not None else
+                    next(all_known_outs_) for f1, f2 in zip(in_fwd, out_fwd)]
+  sentinel = object()
+  assert next(all_known_outs_, sentinel) is sentinel
+  del all_known_outs_, known_inputs
 
-  if num_residuals:
-    known_out_vals, residual_vals = \
-        split_list(all_known_outs, [len(all_known_outs) - num_residuals])
-  else:
-    known_out_vals, residual_vals = all_known_outs, ()
+  known_out_vals, residual_vals = \
+      split_list(all_known_outs, [len(all_known_outs) - num_residuals])
   residual_tracers = [trace.new_instantiated_const(residual) for residual in residual_vals]
 
   # The convention of partial_eval_jaxpr_nounits is to place residual binders
@@ -1597,7 +1593,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   # Prepare unknown tracers
   unknown_params = dict(
       jaxpr=unknown_jaxpr,
-      in_shardings=(keep_where(in_shardings, unknown_ins) + residual_shardings),
+      in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings),
       out_shardings=keep_where(out_shardings, unknown_outs),
       resource_env=resource_env,
       donated_invars=(keep_where(donated_invars, unknown_ins) +
@@ -1626,28 +1622,25 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
 def _pjit_partial_eval_custom_params_updater(
     unks_in: Sequence[bool], inst_in: Sequence[bool],
     kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
-    num_res: int, params_known: dict, params_staged: dict
+    num_res_out: int, num_res_in: int, params_known: dict, params_staged: dict
   ) -> tuple[dict, dict]:
   # prune inputs to jaxpr_known according to unks_in
   donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
   in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
-  if num_res == 0:
-    residual_shardings = []
-  else:
-    residual_shardings = [UNSPECIFIED] * num_res
   _, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
   new_params_known = dict(params_known,
                           in_shardings=tuple(in_shardings_known),
-                          out_shardings=(*out_shardings_known, *residual_shardings),
+                          out_shardings=(*out_shardings_known,
+                                         *[UNSPECIFIED] * num_res_out),
                           donated_invars=tuple(donated_invars_known))
   assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
   assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
 
   # added num_res new inputs to jaxpr_staged, and pruning according to inst_in
   _, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
-  donated_invars_staged = [False] * num_res + donated_invars_staged
+  donated_invars_staged = [False] * num_res_in + donated_invars_staged
   _, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
-  in_shardings_staged = [*residual_shardings, *in_shardings_staged]
+  in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged]
 
   _, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
 
@@ -1262,30 +1262,37 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
   in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
   unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
   in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
-  f = pe.trace_to_subjaxpr_nounits(f, trace.main, False)
+  f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.main, False)
   f = _promote_scalar_residuals(f)
   f_known, aux = pe.partial_eval_wrapper_nounits(
       f, (*in_knowns,), (*in_avals_sharded,))
 
   @as_hashable_function(closure=out_names_thunk)
   def known_out_names():
-    out_knowns, _, jaxpr, _ = aux()
+    in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux()
     _, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
-    assert not any(not v.aval.shape for v in jaxpr.constvars)
-    res_names = ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
-    return (*out_known_names, *res_names)
+    num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
+    return (*out_known_names, *({0: (*mesh.axis_names,)},) * num_res)
 
   known_params = dict(mesh=mesh, in_names=(*known_in_names,),
                       out_names_thunk=known_out_names, check_rep=check_rep,
                       rewrite=rewrite, auto=auto)
   out = shard_map_p.bind(f_known, *in_consts, **known_params)
-  out_knowns, out_avals_sharded, jaxpr, env = aux()
-  out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
-  with core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
+  in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux()
+  num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
+  out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
+  assert not jaxpr.constvars
   unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
-  unk_in_names = (({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
-                  + (*unk_in_names,))
+  known_out_names_ = known_out_names()
+  non_fwd_res_ = iter(non_fwd_res)
+  res = [in_consts[f1] if f1 is not None else out_consts[f2] if f2 is not None
+         else next(non_fwd_res_) for f1, f2 in zip(in_fwd, out_fwd)]
+  sentinel = object()
+  assert next(non_fwd_res_, sentinel) is sentinel
+  res_names = [known_in_names[f1] if f1 is not None else
+               known_out_names_[f2] if f2 is not None else
+               {0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+  unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
   const_tracers = map(trace.new_instantiated_const, res)
   env_tracers = map(trace.full_raise, env)
   unk_arg_tracers = [t for t in tracers if not t.is_known()]
@@ -1307,7 +1314,11 @@ def _shard_map_partial_eval_post_process(
   del check_rep
   unk_tracers = [t for t in tracers if not t.is_known()]
   jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
-  jaxpr, res = _promote_scalar_residuals_jaxpr(jaxpr, res)
+  # TODO(mattjj): output forwarding optimization
+  which = [not v.aval.shape for v in jaxpr.constvars]
+  res = [jax.lax.broadcast(x, (1,)) if not v.aval.shape else x
+         for x, v in zip(res, jaxpr.constvars)]
+  jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
 
   out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
   out = [*consts, *res]
@@ -1317,11 +1328,11 @@ def _shard_map_partial_eval_post_process(
 
   def todo(out):
     trace = main.with_cur_sublevel()
-    out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
-    const_tracers = map(trace.new_instantiated_const, res)
+    out_consts, res_ = split_list(out, [len(out) - len(res)])
+    const_tracers = map(trace.new_instantiated_const, res_)
     env_tracers = map(trace.full_raise, env)
 
-    staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
+    staged_in_names = ({0: (*mesh.axis_names,)},) * len(res_) + ({},) * len(env)
     staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
                          out_names=(*out_names_unknown,), check_rep=False,
                          rewrite=rewrite, auto=auto)
@@ -1339,7 +1350,7 @@ def _shard_map_partial_eval_post_process(
   def out_names_transform(out_names):
     nonlocal out_names_unknown
     out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
-    return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
+    return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(res)
   out_names_unknown: list | None = None
 
   return out, (todo, out_names_transform)
@@ -1347,21 +1358,25 @@ pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
 
 @lu.transformation
 def _promote_scalar_residuals(*args, **kwargs):
-  jaxpr, (out_pvals, out_consts, env) = yield args, kwargs
-  jaxpr, out_consts = _promote_scalar_residuals_jaxpr(jaxpr, out_consts)
-  yield jaxpr, (out_pvals, out_consts, env)
+  jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs
+  which = [f1 is None and f2 is None and not v.aval.shape
+           for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
+  jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
+  out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
+                for x in out_consts]
+  yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
 
-def _promote_scalar_residuals_jaxpr(jaxpr, res):
-  which = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
-           for v in jaxpr.constvars]
-  res_ = [jax.lax.broadcast(x, (1,)) if s else x for x, s in zip(res, which)]
+def _promote_scalar_residuals_jaxpr(jaxpr, which):
   @lu.wrap_init
-  def fun(*args):
-    res = [_rem_singleton(x) if s else x for x, s in zip(res_, which)]
+  def fun(*res_and_args):
+    res, args = split_list(res_and_args, [len(jaxpr.constvars)])
+    res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
     return core.eval_jaxpr(jaxpr, res, *args)
-  jaxpr, _, res = pe.trace_to_jaxpr_dynamic(fun, [v.aval for v in jaxpr.invars])
-  return jaxpr, res
+  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+               for v, w in zip(jaxpr.constvars, which)]
+  in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
+  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(fun, in_avals)
+  return jaxpr
 
 def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
                          check_rep, rewrite, auto):
@@ -1429,66 +1444,89 @@ def _partial_eval_jaxpr_custom_rule(
   with core.extend_axis_env_nd(mesh.shape.items()):
     jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
         pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
-  jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged)
+  num_out_primals = len(jaxpr_known.outvars) - num_res
+  in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
+  out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
+  idx_map = {id(v): i for i, v in enumerate(out_vars)}
+  out_fwd = [idx_map.get(id(v)) for v in res_vars]
+  which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
+  with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
+    jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
+  jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
   ins_known, _ = partition_list(unks_in, eqn.invars)
   out_binders_known, _ = partition_list(unks_out, eqn.outvars)
   _, ins_staged = partition_list(inst_in, eqn.invars)
   _, out_binders_staged = partition_list(inst_out, eqn.outvars)
   newvar = core.gensym([jaxpr_known, jaxpr_staged])
   params_known, params_staged = _pe_custom_params(
-      unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res,
+      unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
       dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
   residuals = [newvar(_unshard_aval(mesh, {0: (*mesh.axis_names,)}, var.aval))
-               for var in jaxpr_staged.invars[:num_res]]
+               for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
   eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                                eqn.primitive, params_known, jaxpr_known.effects,
                                eqn.source_info)
-  eqn_staged = pe.new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
+  residuals_ = iter(residuals)
+  full_res = [ins_known[f1] if f1 is not None else
+              out_binders_known[f2] if f2 is not None else
+              next(residuals_) for f1, f2 in zip(in_fwd, out_fwd)]
+  sentinel = object()
+  assert next(residuals_, sentinel) is sentinel
+  eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
                                 eqn.primitive, params_staged,
                                 jaxpr_staged.effects, eqn.source_info)
   assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is core.Var and not inst]
+  new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}]
   return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
 pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
     _partial_eval_jaxpr_custom_rule
 
-def _add_reshapes(num_res, jaxpr_known, jaxpr_staged):
-  if not num_res: return jaxpr_known, jaxpr_staged
+def _add_reshapes(which, jaxpr_known, jaxpr_staged):
+  # add singleton axes to residuals which are from jaxpr_known and are scalars
+  which_ = [w and not v.aval.shape
+            for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
+  if not any(which_): return jaxpr_known, jaxpr_staged
   assert not jaxpr_known.constvars and not jaxpr_staged.constvars
 
   @lu.wrap_init
   def known(*args):
     out = core.eval_jaxpr(jaxpr_known, (), *args)
-    out_known, res = split_list(out, [len(out) - num_res])
-    return [*out_known, *map(_add_singleton, res)]
+    out_known, res = split_list(out, [len(out) - sum(which)])
+    res = [_add_singleton(x) if not x.shape else x for x in res]
+    return [*out_known, *res]
   avals_in = [v.aval for v in jaxpr_known.invars]
   jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic(known, avals_in)
 
   @lu.wrap_init
   def staged(*args):
-    res_, ins = split_list(args, [num_res])
-    res = map(_rem_singleton, res_)
+    res_, ins = split_list(args, [len(which)])
+    res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
     return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
-  res_avals = [v.aval for v in jaxpr_known.outvars[-num_res:]]
-  avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[num_res:]]]
+  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
+  avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
   jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
 
   return jaxpr_known, jaxpr_staged
 
 def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
-                      num_res, params_known, params_staged):
+                      in_fwd, out_fwd, which, params_known, params_staged):
   # prune inputs to jaxpr_known according to unks_in
   mesh = params_known['mesh']
   in_names_known, _ = partition_list(unks_in, params_known['in_names'])
   _, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
-  out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * num_res
+  out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * sum(which)
   new_params_known = dict(params_known, in_names=tuple(in_names_known),
                           out_names=tuple(out_names_known))
 
   # added num_res new inputs to jaxpr_staged, pruning according to inst_in
   _, in_names_staged = partition_list(inst_in, params_staged['in_names'])
-  in_names_staged = [{0: (*mesh.axis_names,)}] * num_res + in_names_staged
+  res_names = [in_names_known[f1] if f1 is not None else
+               out_names_known[f2] if f2 is not None else
+               {0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+  in_names_staged = res_names + in_names_staged
   _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
   new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
                            out_names=tuple(out_names_staged), check_rep=False)
@@ -27,6 +27,7 @@ from absl.testing import parameterized
 import numpy as np
 
 import jax
+import jax.ad_checkpoint
 from jax import lax
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
@@ -1135,6 +1136,7 @@ class ShardMapTest(jtu.JaxTestCase):
 
     jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
 
+
   def test_axis_env_extension_regression(self):
     def foo(x):
       i = jax.lax.axis_index('x')
@@ -1147,6 +1149,51 @@ class ShardMapTest(jtu.JaxTestCase):
 
     jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.))  # doesn't crash
 
+  @parameterized.parameters(it.product([True, False], repeat=2))
+  def test_res_forwarding_optimization(self, jit, remat):
+    mesh = jtu.create_global_mesh((4,), ('i',))
+
+    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
+    def f(x):
+      return jax.lax.exp(x)
+    if jit:
+      f = jax.jit(f)
+    if remat:
+      policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
+      f = jax.remat(f, policy=policy)
+    g = lambda x: f(x).sum()
+
+    x = jnp.arange(16.)
+    jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
+    jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
+    e1, _, e2 = jaxpr.eqns
+    self.assertLen(e1.outvars, 1)  # only primal output
+    self.assertLen(e2.invars, 2)   # res and cotangent inputs
+    self.assertEqual(sum([e1.outvars[0] is v for v in e2.invars]), 1)
+
+  @parameterized.parameters(it.product([True, False], repeat=2))
+  def test_res_forwarding_optimization_complex(self, jit, remat):
+    # like the above test, but a different function `f`
+    mesh = jtu.create_global_mesh((4,), ('i',))
+
+    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
+    def f(x):
+      return jax.lax.exp(x.sum()) + x, jax.lax.exp(x)
+    if jit:
+      f = jax.jit(f)
+    if remat:
+      policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
+      f = jax.remat(f, policy=policy)
+    g = lambda x: sum(f(x)).sum()
+
+    x = jnp.arange(16.)
+    jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
+    jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
+    e1, _, e2 = jaxpr.eqns
+    self.assertLen(e1.outvars, 2)  # one primal and one res output
+    self.assertLen(e2.invars, 4)   # two res and two cotangent inputs
+    self.assertEqual(sum([e1.outvars[-1] is v for v in e2.invars]), 1)
+
 
 class FunSpec(NamedTuple):
   name: str