mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix typo in the error message
PiperOrigin-RevId: 745375892
This commit is contained in:
parent
f95f6a8bdb
commit
42751359e6
@ -1427,8 +1427,8 @@ def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_):
|
||||
_, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts])
|
||||
carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in])
|
||||
if tuple(carry_rep_in) != tuple(carry_rep_out):
|
||||
raise Exception("Scanwhile_loopcarry input and output got mismatched "
|
||||
"replication types {carry_rep_in} and {carry_rep_out}. "
|
||||
raise Exception("while_loop carry input and output got mismatched "
|
||||
f"replication types {carry_rep_in} and {carry_rep_out}. "
|
||||
"Please open an issue at "
|
||||
"https://github.com/jax-ml/jax/issues, and as a temporary "
|
||||
"workaround pass the check_rep=False argument to shard_map")
|
||||
|
Loading…
x
Reference in New Issue
Block a user