Fix typo in the error message

PiperOrigin-RevId: 745375892
This commit is contained in:
Yash Katariya 2025-04-08 18:59:34 -07:00 committed by jax authors
parent f95f6a8bdb
commit 42751359e6

View File

@ -1427,8 +1427,8 @@ def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_):
_, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts])
carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in])
if tuple(carry_rep_in) != tuple(carry_rep_out):
raise Exception("Scanwhile_loopcarry input and output got mismatched "
"replication types {carry_rep_in} and {carry_rep_out}. "
raise Exception("while_loop carry input and output got mismatched "
f"replication types {carry_rep_in} and {carry_rep_out}. "
"Please open an issue at "
"https://github.com/jax-ml/jax/issues, and as a temporary "
"workaround pass the check_rep=False argument to shard_map")