From 42751359e690f31fda62c583361d6207d43d2597 Mon Sep 17 00:00:00 2001 From: Yash Katariya <yashkatariya@google.com> Date: Tue, 8 Apr 2025 18:59:34 -0700 Subject: [PATCH] Fix typo in the error message PiperOrigin-RevId: 745375892 --- jax/experimental/shard_map.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c49a3de17..189a3d9b4 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1427,8 +1427,8 @@ def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_): _, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts]) carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in]) if tuple(carry_rep_in) != tuple(carry_rep_out): - raise Exception("Scanwhile_loopcarry input and output got mismatched " - "replication types {carry_rep_in} and {carry_rep_out}. " + raise Exception("while_loop carry input and output got mismatched " + f"replication types {carry_rep_in} and {carry_rep_out}. " "Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map")