From 42751359e690f31fda62c583361d6207d43d2597 Mon Sep 17 00:00:00 2001
From: Yash Katariya <yashkatariya@google.com>
Date: Tue, 8 Apr 2025 18:59:34 -0700
Subject: [PATCH] Fix typo in the error message

PiperOrigin-RevId: 745375892
---
 jax/experimental/shard_map.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index c49a3de17..189a3d9b4 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -1427,8 +1427,8 @@ def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_):
   _, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts])
   carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in])
   if tuple(carry_rep_in) != tuple(carry_rep_out):
-    raise Exception("Scanwhile_loopcarry input and output got mismatched "
-                    "replication types {carry_rep_in} and {carry_rep_out}. "
+    raise Exception("while_loop carry input and output got mismatched "
+                    f"replication types {carry_rep_in} and {carry_rep_out}. "
                     "Please open an issue at "
                     "https://github.com/jax-ml/jax/issues, and as a temporary "
                     "workaround pass the check_rep=False argument to shard_map")