Merge pull request #13958 from mattjj:pjit-partial-eval-2

PiperOrigin-RevId: 501319644
This commit is contained in:
jax authors 2023-01-11 10:36:39 -08:00
commit 7a6c75339f
3 changed files with 61 additions and 18 deletions

View File

@ -146,6 +146,7 @@ class ClosedJaxpr:
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
  """Pair `jaxpr` with the concrete values for its constant variables.

  Args:
    jaxpr: the jaxpr being closed over; its `constvars` must line up
      one-to-one with `consts`.
    consts: one value per entry of `jaxpr.constvars`. A list copy is stored,
      so later mutation of the caller's sequence does not affect this object.
  """
  num_consts = len(consts)
  assert num_consts == len(jaxpr.constvars)
  # TODO(mattjj): enable the check below once it is safe to do so.
  # assert not any(isinstance(c, Tracer) for c in consts)
  self.jaxpr = jaxpr
  self.consts = [*consts]

View File

@ -1979,3 +1979,41 @@ def _get_pspec_from_executable(
out_partition_spec = _get_partition_spec(out_ppspec)
in_partition_spec = _get_partition_spec(in_ppspec)
return tuple(in_partition_spec), tuple(out_partition_spec)
def _pjit_partial_eval_custom_params_updater(
    unks_in: Sequence[bool], inst_in: Sequence[bool],
    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
    num_res: int, params_known: dict, params_staged: dict
  ) -> Tuple[dict, dict]:
  """Adjust pjit params after an equation is split into known/staged halves.

  The per-input and per-output entries (`donated_invars`, `in_shardings`,
  `out_shardings`) of each params dict are filtered with
  `pe.partition_list` and extended to account for `num_res` residuals, which
  become trailing outputs of the known half and leading inputs of the staged
  half.

  NOTE(review): `pe.partition_list(flags, xs)` is defined elsewhere; this
  function relies on it returning `(xs where flag is False, xs where flag is
  True)` -- confirm against its definition.

  Args:
    unks_in: per-input flags; the first partition is kept for the known half.
    inst_in: per-input flags; the second partition is kept for the staged half.
    kept_outs_known: per-output flags selecting outputs of the known half.
    kept_outs_staged: per-output flags selecting outputs of the staged half.
    num_res: number of residuals passed from the known to the staged half.
    params_known: params whose 'jaxpr' entry is already the known jaxpr.
    params_staged: params whose 'jaxpr' entry is already the staged jaxpr.

  Returns:
    `(new_params_known, new_params_staged)`: copies of the inputs with
    `in_shardings`, `out_shardings` and `donated_invars` replaced by tuples.
  """
  # --- known half: drop the inputs partitioned away by unks_in ---
  known_donated, _ = pe.partition_list(unks_in, params_known['donated_invars'])
  known_in_sh, _ = pe.partition_list(unks_in, params_known['in_shardings'])
  # One unspecified sharding per residual (an empty list when num_res == 0).
  res_shardings = [_UNSPECIFIED] * num_res
  _, known_out_sh = pe.partition_list(
      kept_outs_known, params_known['out_shardings'])
  new_params_known = dict(
      params_known,
      in_shardings=tuple(known_in_sh),
      # residuals are appended after the kept outputs
      out_shardings=tuple(known_out_sh) + tuple(res_shardings),
      donated_invars=tuple(known_donated))
  known_jaxpr = params_known['jaxpr']
  assert len(new_params_known['in_shardings']) == len(known_jaxpr.in_avals)
  assert len(new_params_known['out_shardings']) == len(known_jaxpr.out_avals)

  # --- staged half: num_res residual inputs go first, then inputs kept by
  # inst_in; residual inputs are never donated ---
  _, staged_donated = pe.partition_list(inst_in, params_staged['donated_invars'])
  staged_donated = [False] * num_res + staged_donated
  _, staged_in_sh = pe.partition_list(inst_in, params_staged['in_shardings'])
  staged_in_sh = res_shardings + list(staged_in_sh)
  _, staged_out_sh = pe.partition_list(
      kept_outs_staged, params_staged['out_shardings'])
  new_params_staged = dict(
      params_staged,
      in_shardings=tuple(staged_in_sh),
      out_shardings=tuple(staged_out_sh),
      donated_invars=tuple(staged_donated))
  staged_jaxpr = params_staged['jaxpr']
  assert len(new_params_staged['in_shardings']) == len(staged_jaxpr.in_avals)
  assert len(new_params_staged['out_shardings']) == len(staged_jaxpr.out_avals)
  return new_params_known, new_params_staged
# Register pjit's custom partial-eval rule: reuse the generic closed-call rule,
# reading the body from the 'jaxpr' param and fixing up shardings/donation
# with the pjit-specific params updater.
pe.partial_eval_jaxpr_custom_rules[pjit_p] = partial(
    pe.closed_call_partial_eval_custom_rule,
    'jaxpr',
    _pjit_partial_eval_custom_params_updater,
)

View File

@ -1257,34 +1257,37 @@ def call_partial_eval_custom_rule(
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
def closed_call_partial_eval_custom_rule(
    jaxpr_param_name: str, params_updater: ParamsUpdater,
    saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
    eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
  ) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
  """Partial-eval rule for equations whose body is a ClosedJaxpr param.

  Splits `eqn` into a "known" equation and a "staged" equation by running
  `partial_eval_jaxpr_custom` on the inner jaxpr, keeping the original consts
  on both halves, and threading `num_res` residual variables from the known
  equation's outputs to the staged equation's inputs.

  Args:
    jaxpr_param_name: key in `eqn.params` holding the ClosedJaxpr body.
    params_updater: called as `params_updater(unks_in, inst_in,
      kept_outs_known, inst_out, num_res, params_known, params_staged)` and
      must return the final `(params_known, params_staged)` pair.
    saveable: forwarded to `partial_eval_jaxpr_custom`.
    unks_in: per-input flags used to pick the known equation's inputs.
    inst_in: per-input flags used to pick the staged equation's inputs.
    eqn: the equation being split.
    res_aval: maps `(params_known, aval)` to the aval given to each new
      residual variable.

  Returns:
    `(eqn_known, eqn_staged, unks_out, inst_out, new_vars)` where `new_vars`
    is the `Var` inputs of `eqn` whose `inst_in` flag is False, followed by
    the residual variables.
  """
  # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
  closed_jaxpr = eqn.params[jaxpr_param_name]
  # Partial-eval the inner jaxpr directly (consts stay as consts), so
  # unks_in/inst_in keep lining up one-to-one with eqn.invars below.
  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
      partial_eval_jaxpr_custom(closed_jaxpr.jaxpr, unks_in, inst_in,
                                False, False, saveable)
  # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  _, ins_staged = partition_list(inst_in, eqn.invars)
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr])
  params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
  params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
  # Outputs kept by the known half are exactly the ones not marked unknown.
  # Built as a concrete list so the updater always receives a real sequence.
  kept_outs_known = [not unk for unk in unks_out]
  params_known, params_staged = params_updater(
      unks_in, inst_in, kept_outs_known, inst_out, num_res, params_known,
      params_staged)
  # The first num_res inputs of the staged jaxpr are the residuals.
  residuals = [newvar(res_aval(params_known, a))
               for a in jaxpr_staged.in_avals[:num_res]]
  eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                            eqn.primitive, params_known, jaxpr_known.effects,
                            eqn.source_info)
  eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
                             eqn.primitive, params_staged, jaxpr_staged.effects,
                             eqn.source_info)
  assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is Var and not inst]
  return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
@ -1293,7 +1296,8 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, _____, x, y: (x, y))
# closed_call has no per-input/per-output params to fix up, so its params
# updater ignores the first five arguments and returns
# (params_known, params_staged) unchanged.
partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
    partial(closed_call_partial_eval_custom_rule, 'call_jaxpr',
            lambda _, __, ___, ____, _____, x, y: (x, y))
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]: