mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13958 from mattjj:pjit-partial-eval-2
PiperOrigin-RevId: 501319644
This commit is contained in:
commit
7a6c75339f
@ -146,6 +146,7 @@ class ClosedJaxpr:
|
||||
|
||||
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
|
||||
assert len(consts) == len(jaxpr.constvars)
|
||||
# assert not any(isinstance(c, Tracer) for c in consts) # TODO(mattjj): enable
|
||||
self.jaxpr = jaxpr
|
||||
self.consts = list(consts)
|
||||
|
||||
|
@ -1979,3 +1979,41 @@ def _get_pspec_from_executable(
|
||||
out_partition_spec = _get_partition_spec(out_ppspec)
|
||||
in_partition_spec = _get_partition_spec(in_ppspec)
|
||||
return tuple(in_partition_spec), tuple(out_partition_spec)
|
||||
|
||||
def _pjit_partial_eval_custom_params_updater(
|
||||
unks_in: Sequence[bool], inst_in: Sequence[bool],
|
||||
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
|
||||
num_res: int, params_known: dict, params_staged: dict
|
||||
) -> Tuple[dict, dict]:
|
||||
# prune inputs to jaxpr_known according to unks_in
|
||||
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
|
||||
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
|
||||
if num_res == 0:
|
||||
residual_shardings = []
|
||||
else:
|
||||
residual_shardings = [_UNSPECIFIED] * num_res
|
||||
_, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
|
||||
new_params_known = dict(params_known,
|
||||
in_shardings=tuple(in_shardings_known),
|
||||
out_shardings=(*out_shardings_known, *residual_shardings),
|
||||
donated_invars=tuple(donated_invars_known))
|
||||
assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
|
||||
assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
|
||||
|
||||
# added num_res new inputs to jaxpr_staged, and pruning according to inst_in
|
||||
_, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
|
||||
donated_invars_staged = [False] * num_res + donated_invars_staged
|
||||
_, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
|
||||
in_shardings_staged = [*residual_shardings, *in_shardings_staged]
|
||||
_, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
|
||||
new_params_staged = dict(params_staged,
|
||||
in_shardings=tuple(in_shardings_staged),
|
||||
out_shardings=tuple(out_shardings_staged),
|
||||
donated_invars=tuple(donated_invars_staged))
|
||||
assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals)
|
||||
assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals)
|
||||
return new_params_known, new_params_staged
|
||||
|
||||
pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
|
||||
partial(pe.closed_call_partial_eval_custom_rule, 'jaxpr',
|
||||
_pjit_partial_eval_custom_params_updater)
|
||||
|
@ -1257,34 +1257,37 @@ def call_partial_eval_custom_rule(
|
||||
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
|
||||
|
||||
def closed_call_partial_eval_custom_rule(
|
||||
jaxpr_param_name: str,
|
||||
jaxpr_param_name: str, params_updater: ParamsUpdater,
|
||||
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
|
||||
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
|
||||
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
|
||||
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
|
||||
closed_jaxpr = eqn.params[jaxpr_param_name]
|
||||
jaxpr = convert_constvars_jaxpr(closed_jaxpr.jaxpr)
|
||||
unks_in = [False] * len(closed_jaxpr.consts) + list(unks_in)
|
||||
inst_in = [False] * len(closed_jaxpr.consts) + list(inst_in)
|
||||
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
|
||||
partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
|
||||
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
|
||||
partial_eval_jaxpr_custom(closed_jaxpr.jaxpr, unks_in, inst_in,
|
||||
False, False, saveable)
|
||||
# Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
|
||||
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
|
||||
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
|
||||
ins_known, _ = partition_list(unks_in, eqn.invars)
|
||||
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
||||
_, ins_staged = partition_list(inst_in, eqn.invars)
|
||||
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
||||
newvar = core.gensym([jaxpr_known, jaxpr_staged])
|
||||
params_known = {**eqn.params, jaxpr_param_name: core.ClosedJaxpr(jaxpr_known,
|
||||
())}
|
||||
params_staged = {**eqn.params, jaxpr_param_name:
|
||||
core.ClosedJaxpr(jaxpr_staged, ())}
|
||||
residuals = [newvar(res_aval(params_known, var.aval))
|
||||
for var in jaxpr_staged.invars[:num_res]]
|
||||
newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr])
|
||||
params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
|
||||
params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
|
||||
params_known, params_staged = params_updater(
|
||||
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
|
||||
params_staged)
|
||||
residuals = [newvar(res_aval(params_known, a))
|
||||
for a in jaxpr_staged.in_avals[:num_res]]
|
||||
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
|
||||
eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
|
||||
eqn.primitive, params_known, jaxpr_known.effects,
|
||||
eqn.source_info)
|
||||
eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
|
||||
eqn.primitive, params_staged,
|
||||
jaxpr_staged.effects, eqn.source_info)
|
||||
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
|
||||
eqn.primitive, params_staged, jaxpr_staged.effects,
|
||||
eqn.source_info)
|
||||
assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
|
||||
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
||||
if type(x) is Var and not inst]
|
||||
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
|
||||
@ -1293,7 +1296,8 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
|
||||
partial(call_partial_eval_custom_rule, 'call_jaxpr',
|
||||
lambda _, __, ___, ____, _____, x, y: (x, y))
|
||||
partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
|
||||
partial(closed_call_partial_eval_custom_rule, 'call_jaxpr')
|
||||
partial(closed_call_partial_eval_custom_rule, 'call_jaxpr',
|
||||
lambda _, __, ___, ____, _____, x, y: (x, y))
|
||||
|
||||
|
||||
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user