From a67ab9fade345cdb60bdece5bcf6b97793938966 Mon Sep 17 00:00:00 2001
From: Yash Katariya <yashkatariya@google.com>
Date: Wed, 5 Mar 2025 20:08:54 -0800
Subject: [PATCH] Just use `jit` as the string in error messages instead of
 `jit` and `pjit` based on resource_env. This is to start deprecating the need
 for `with mesh` and replace it with `use_mesh(mesh)`.

PiperOrigin-RevId: 733959962
---
 jax/_src/pjit.py         | 18 +++++++-----------
 tests/name_stack_test.py |  6 +++---
 tests/pjit_test.py       | 22 +++++++++++-----------
 3 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index dad5c949a..86df66301 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -199,10 +199,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
       profiler = None
   except pxla.DeviceAssignmentMismatchError as e:
     fails, = e.args
-    api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
     fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
     msg = _device_assignment_mismatch_error(
-        fun_name, fails, args_flat, api_name, p.arg_names)
+        fun_name, fails, args_flat, 'jit', p.arg_names)
     raise ValueError(msg) from None
   except xla.InvalidInputException as e:
     arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names
@@ -591,13 +590,12 @@ def _infer_params_impl(
     in_shardings_leaves = out_shardings_leaves = tuple(leaves)
     in_shardings_treedef = out_shardings_treedef = treedef
   else:
-    jit_name = 'pjit' if pjit_mesh is not None else 'jit'
     in_shardings_leaves = tuple(
-        _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
+        _create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit')
         for x in ji.in_shardings_leaves)
     in_shardings_treedef = ji.in_shardings_treedef
     out_shardings_leaves = tuple(
-        _create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name)
+        _create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit')
         for x in ji.out_shardings_leaves)
     out_shardings_treedef = ji.out_shardings_treedef
 
@@ -1760,12 +1758,10 @@ def _pjit_lower(
     lowering_parameters: mlir.LoweringParameters,
     pgle_profiler: profiler.PGLEProfiler | None):
   util.test_event("pjit_lower")
-  if resource_env is not None:
-    mesh, api_name = resource_env.physical_mesh, 'pjit'
-  else:
-    mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
+  mesh = (resource_env.physical_mesh if resource_env is not None else
+          mesh_lib.get_concrete_mesh())
   return pxla.lower_sharding_computation(
-      jaxpr, api_name, name, in_shardings, out_shardings,
+      jaxpr, 'jit', name, in_shardings, out_shardings,
       in_layouts, out_layouts, tuple(donated_invars),
       keep_unused=keep_unused, context_mesh=mesh,
       compiler_options_kvs=compiler_options_kvs,
@@ -1929,7 +1925,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
   func = _pjit_cached_lower_jaxpr_to_fun(
       ctx, name, jaxpr, tuple(effects), in_shardings,
       out_shardings, in_layouts, out_layouts,
-      api_name=('jit' if resource_env is None else 'pjit'))
+      api_name='jit')
 
   tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
   args = (*ctx.dim_var_values, *tokens_in, *args)
diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py
index 270707934..f371c431e 100644
--- a/tests/name_stack_test.py
+++ b/tests/name_stack_test.py
@@ -263,9 +263,9 @@ class NameStackTransformationTest(jtu.JaxTestCase):
       return g(x)
 
     hlo_text = _get_hlo(f)(2.)
-    self.assertIn('jvp(pjit(f))/pjit(g)/sin', hlo_text)
-    self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text)
-    self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text)
+    self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text)
+    self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text)
+    self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text)
 
   def test_remat_appears_in_hlo(self):
     @ad_checkpoint.remat
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 7a7f6b7d6..4c20ac649 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -2076,7 +2076,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     with global_mesh:
       with self.assertRaisesRegex(
-          ValueError, "Received incompatible devices for pjitted computation"):
+          ValueError, "Received incompatible devices for jitted computation"):
         pjit(lambda x: x)(input_array)
 
   def test_array_lower_compile(self):
@@ -2177,7 +2177,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     with m1:
       with self.assertRaisesRegex(
-          ValueError, "Received incompatible devices for pjitted computation"):
+          ValueError, "Received incompatible devices for jitted computation"):
         pjit(lambda x, y: (x, y),
               out_shardings=(NamedSharding(m1, spec),
                              NamedSharding(m2, spec)))(a1, a1)
@@ -2192,7 +2192,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     with m1:
       with self.assertRaisesRegex(
-          ValueError, "Received incompatible devices for pjitted computation"):
+          ValueError, "Received incompatible devices for jitted computation"):
         pjit(
             lambda x, y: (x, y),
             in_shardings=NamedSharding(m2, spec),
@@ -2348,7 +2348,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     arr = jnp.array([1, 2, 3])
     with self.assertRaisesRegex(
         RuntimeError,
-        r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
+        r'jit requires a non-empty mesh if you are passing `PartitionSpec`s or'
         r' `None` to in_shardings.*'):
       pjit(lambda x: x, in_shardings=P('x'))(arr)
 
@@ -2396,7 +2396,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with jtu.create_mesh((2, 2), ('x', 'y')):
       with self.assertRaisesRegex(
           ValueError,
-          "Received incompatible devices for pjitted computation"):
+          "Received incompatible devices for jitted computation"):
         pjit(lambda x, y: (x, y))(uarr, carr)
 
   def test_pjit_uncommitted_array_multi_devices(self):
@@ -2418,7 +2418,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
     with self.assertRaisesRegex(
         ValueError,
-        "Received incompatible devices for pjitted computation. Got argument "
+        "Received incompatible devices for jitted computation. Got argument "
         r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
         r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
       pjit(lambda x, y: (x, y))(a, b)
@@ -2430,7 +2430,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
     with self.assertRaisesRegex(
         ValueError,
-        "Received incompatible devices for pjitted computation. Got argument "
+        "Received incompatible devices for jitted computation. Got argument "
         r"x\[0\] of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
         r"argument x\[1\] of.*\<lambda\> with shape int.*\[3\] and device ids "
         r"\[1\].*"):
@@ -2443,7 +2443,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     c = jax.device_put(np.arange(16).reshape(8, 2),
                        NamedSharding(mesh, P('x', 'y')))
 
-    msg = ("Received incompatible devices for pjitted computation. Got "
+    msg = ("Received incompatible devices for jitted computation. Got "
            r"argument {} of.*<lambda> with shape int.*\[3\] and device ids "
            r"\[0\].*and argument {} of.*<lambda> with shape int.*\[8,2\] and "
            r"device ids.*")
@@ -2617,9 +2617,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
       return f(inp1, inp2, inp3)
     with self.assertRaisesRegex(
         ValueError,
-        "Received incompatible devices for pjitted computation. Got argument "
+        "Received incompatible devices for jitted computation. Got argument "
         r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*"
-        r"pjit inside pjit with device ids.*"):
+        r"pjit inside jit with device ids.*"):
       my_nested_pjit(committed_inp, committed_inp, committed_inp)
 
   @jtu.ignore_warning(category=DeprecationWarning,
@@ -7236,7 +7236,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     xshape = (2, 5, 6)
     x = jnp.arange(math.prod(xshape)).reshape(xshape)
     with self.assertRaisesRegex(
-        ValueError, "Received incompatible devices for pjitted computation.*"):
+        ValueError, "Received incompatible devices for jitted computation.*"):
       f(x)
 
   @parameterized.named_parameters(