Add in_positional_semantics to new_params_known and new_params_staged otherwise it leads to length mismatch error down the stack. It is similar to donated_invars and in_shardings.

PiperOrigin-RevId: 502082828
This commit is contained in:
Yash Katariya 2023-01-14 10:18:28 -08:00 committed by jax authors
parent 38f91bdaa5
commit 4c58ef3840

View File

@ -1544,6 +1544,8 @@ def _pjit_partial_eval_custom_params_updater(
# prune inputs to jaxpr_known according to unks_in
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
in_positional_semantics_known, _ = pe.partition_list(
unks_in, params_known['in_positional_semantics'])
if num_res == 0:
residual_shardings = []
else:
@ -1552,7 +1554,8 @@ def _pjit_partial_eval_custom_params_updater(
new_params_known = dict(params_known,
in_shardings=tuple(in_shardings_known),
out_shardings=(*out_shardings_known, *residual_shardings),
donated_invars=tuple(donated_invars_known))
donated_invars=tuple(donated_invars_known),
in_positional_semantics=tuple(in_positional_semantics_known))
assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
@ -1561,11 +1564,18 @@ def _pjit_partial_eval_custom_params_updater(
donated_invars_staged = [False] * num_res + donated_invars_staged
_, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
in_shardings_staged = [*residual_shardings, *in_shardings_staged]
_, in_positional_semantics_staged = pe.partition_list(
inst_in, params_staged['in_positional_semantics'])
in_positional_semantics_staged = [
pxla._PositionalSemantics.GLOBAL] * num_res + in_positional_semantics_staged
_, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
new_params_staged = dict(params_staged,
in_shardings=tuple(in_shardings_staged),
out_shardings=tuple(out_shardings_staged),
donated_invars=tuple(donated_invars_staged))
donated_invars=tuple(donated_invars_staged),
in_positional_semantics=tuple(in_positional_semantics_staged))
assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals)
assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals)
return new_params_known, new_params_staged