From a67ab9fade345cdb60bdece5bcf6b97793938966 Mon Sep 17 00:00:00 2001 From: Yash Katariya <yashkatariya@google.com> Date: Wed, 5 Mar 2025 20:08:54 -0800 Subject: [PATCH] Just use `jit` as the string in error messages instead of `jit` and `pjit` based on resource_env. This is to start deprecating the need for `with mesh` and replace it with `use_mesh(mesh)`. PiperOrigin-RevId: 733959962 --- jax/_src/pjit.py | 18 +++++++----------- tests/name_stack_test.py | 6 +++--- tests/pjit_test.py | 22 +++++++++++----------- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index dad5c949a..86df66301 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -199,10 +199,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): profiler = None except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if p.params['resource_env'] is None else 'pjit' fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, p.arg_names) + fun_name, fails, args_flat, 'jit', p.arg_names) raise ValueError(msg) from None except xla.InvalidInputException as e: arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names @@ -591,13 +590,12 @@ def _infer_params_impl( in_shardings_leaves = out_shardings_leaves = tuple(leaves) in_shardings_treedef = out_shardings_treedef = treedef else: - jit_name = 'pjit' if pjit_mesh is not None else 'jit' in_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name) + _create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit') for x in ji.in_shardings_leaves) in_shardings_treedef = ji.in_shardings_treedef out_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name) + _create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit') for x in ji.out_shardings_leaves) out_shardings_treedef = ji.out_shardings_treedef @@ -1760,12 +1758,10 @@ def _pjit_lower( lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): util.test_event("pjit_lower") - if resource_env is not None: - mesh, api_name = resource_env.physical_mesh, 'pjit' - else: - mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' + mesh = (resource_env.physical_mesh if resource_env is not None else + mesh_lib.get_concrete_mesh()) return pxla.lower_sharding_computation( - jaxpr, api_name, name, in_shardings, out_shardings, + jaxpr, 'jit', name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), keep_unused=keep_unused, context_mesh=mesh, compiler_options_kvs=compiler_options_kvs, @@ -1929,7 +1925,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, func = _pjit_cached_lower_jaxpr_to_fun( ctx, name, jaxpr, tuple(effects), in_shardings, out_shardings, in_layouts, out_layouts, - api_name=('jit' if resource_env is None else 'pjit')) + api_name='jit') tokens_in = [ctx.tokens_in.get(eff) for eff in effects] args = (*ctx.dim_var_values, *tokens_in, *args) diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 270707934..f371c431e 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -263,9 +263,9 @@ class NameStackTransformationTest(jtu.JaxTestCase): return g(x) hlo_text = _get_hlo(f)(2.) - self.assertIn('jvp(pjit(f))/pjit(g)/sin', hlo_text) - self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text) - self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text) + self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text) + self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text) + self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text) def test_remat_appears_in_hlo(self): @ad_checkpoint.remat diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7a7f6b7d6..4c20ac649 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2076,7 +2076,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with global_mesh: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit(lambda x: x)(input_array) def test_array_lower_compile(self): @@ -2177,7 +2177,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with m1: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit(lambda x, y: (x, y), out_shardings=(NamedSharding(m1, spec), NamedSharding(m2, spec)))(a1, a1) @@ -2192,7 +2192,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with m1: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit( lambda x, y: (x, y), in_shardings=NamedSharding(m2, spec), @@ -2348,7 +2348,7 @@ class ArrayPjitTest(jtu.JaxTestCase): arr = jnp.array([1, 2, 3]) with self.assertRaisesRegex( RuntimeError, - r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or' + r'jit requires a non-empty mesh if you are passing `PartitionSpec`s or' r' `None` to in_shardings.*'): pjit(lambda x: x, in_shardings=P('x'))(arr) @@ -2396,7 +2396,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with jtu.create_mesh((2, 2), ('x', 'y')): with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation"): + "Received incompatible devices for jitted computation"): pjit(lambda x, y: (x, y))(uarr, carr) def test_pjit_uncommitted_array_multi_devices(self): @@ -2418,7 +2418,7 @@ class ArrayPjitTest(jtu.JaxTestCase): b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and " r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"): pjit(lambda x, y: (x, y))(a, b) @@ -2430,7 +2430,7 @@ class ArrayPjitTest(jtu.JaxTestCase): b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"x\[0\] of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and " r"argument x\[1\] of.*\<lambda\> with shape int.*\[3\] and device ids " r"\[1\].*"): @@ -2443,7 +2443,7 @@ class ArrayPjitTest(jtu.JaxTestCase): c = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y'))) - msg = ("Received incompatible devices for pjitted computation. Got " + msg = ("Received incompatible devices for jitted computation. Got " r"argument {} of.*<lambda> with shape int.*\[3\] and device ids " r"\[0\].*and argument {} of.*<lambda> with shape int.*\[8,2\] and " r"device ids.*") @@ -2617,9 +2617,9 @@ class ArrayPjitTest(jtu.JaxTestCase): return f(inp1, inp2, inp3) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*" - r"pjit inside pjit with device ids.*"): + r"pjit inside jit with device ids.*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) @jtu.ignore_warning(category=DeprecationWarning, @@ -7236,7 +7236,7 @@ class PJitErrorTest(jtu.JaxTestCase): xshape = (2, 5, 6) x = jnp.arange(math.prod(xshape)).reshape(xshape) with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation.*"): + ValueError, "Received incompatible devices for jitted computation.*"): f(x) @parameterized.named_parameters(