mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add in_positional_semantics to new_params_known and new_params_staged otherwise it leads to length mismatch error down the stack. It is similar to donated_invars and in_shardings.
PiperOrigin-RevId: 502082828
This commit is contained in:
parent
38f91bdaa5
commit
4c58ef3840
@ -1544,6 +1544,8 @@ def _pjit_partial_eval_custom_params_updater(
|
||||
# prune inputs to jaxpr_known according to unks_in
|
||||
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
|
||||
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
|
||||
in_positional_semantics_known, _ = pe.partition_list(
|
||||
unks_in, params_known['in_positional_semantics'])
|
||||
if num_res == 0:
|
||||
residual_shardings = []
|
||||
else:
|
||||
@ -1552,7 +1554,8 @@ def _pjit_partial_eval_custom_params_updater(
|
||||
new_params_known = dict(params_known,
|
||||
in_shardings=tuple(in_shardings_known),
|
||||
out_shardings=(*out_shardings_known, *residual_shardings),
|
||||
donated_invars=tuple(donated_invars_known))
|
||||
donated_invars=tuple(donated_invars_known),
|
||||
in_positional_semantics=tuple(in_positional_semantics_known))
|
||||
assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
|
||||
assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
|
||||
|
||||
@ -1561,11 +1564,18 @@ def _pjit_partial_eval_custom_params_updater(
|
||||
donated_invars_staged = [False] * num_res + donated_invars_staged
|
||||
_, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
|
||||
in_shardings_staged = [*residual_shardings, *in_shardings_staged]
|
||||
_, in_positional_semantics_staged = pe.partition_list(
|
||||
inst_in, params_staged['in_positional_semantics'])
|
||||
in_positional_semantics_staged = [
|
||||
pxla._PositionalSemantics.GLOBAL] * num_res + in_positional_semantics_staged
|
||||
|
||||
_, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
|
||||
|
||||
new_params_staged = dict(params_staged,
|
||||
in_shardings=tuple(in_shardings_staged),
|
||||
out_shardings=tuple(out_shardings_staged),
|
||||
donated_invars=tuple(donated_invars_staged))
|
||||
donated_invars=tuple(donated_invars_staged),
|
||||
in_positional_semantics=tuple(in_positional_semantics_staged))
|
||||
assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals)
|
||||
assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals)
|
||||
return new_params_known, new_params_staged
|
||||
|
Loading…
x
Reference in New Issue
Block a user