Just use jit as the string in error messages instead of jit and pjit based on resource_env. This is to start deprecating the need for with mesh and replace it with use_mesh(mesh).

PiperOrigin-RevId: 733959962
This commit is contained in:
Yash Katariya 2025-03-05 20:08:54 -08:00 committed by jax authors
parent ba5349f896
commit a67ab9fade
3 changed files with 21 additions and 25 deletions

View File

@ -199,10 +199,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
profiler = None
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, p.arg_names)
fun_name, fails, args_flat, 'jit', p.arg_names)
raise ValueError(msg) from None
except xla.InvalidInputException as e:
arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names
@ -591,13 +590,12 @@ def _infer_params_impl(
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
in_shardings_treedef = out_shardings_treedef = treedef
else:
jit_name = 'pjit' if pjit_mesh is not None else 'jit'
in_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit')
for x in ji.in_shardings_leaves)
in_shardings_treedef = ji.in_shardings_treedef
out_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name)
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit')
for x in ji.out_shardings_leaves)
out_shardings_treedef = ji.out_shardings_treedef
@ -1760,12 +1758,10 @@ def _pjit_lower(
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
util.test_event("pjit_lower")
if resource_env is not None:
mesh, api_name = resource_env.physical_mesh, 'pjit'
else:
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
mesh = (resource_env.physical_mesh if resource_env is not None else
mesh_lib.get_concrete_mesh())
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
jaxpr, 'jit', name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
keep_unused=keep_unused, context_mesh=mesh,
compiler_options_kvs=compiler_options_kvs,
@ -1929,7 +1925,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
func = _pjit_cached_lower_jaxpr_to_fun(
ctx, name, jaxpr, tuple(effects), in_shardings,
out_shardings, in_layouts, out_layouts,
api_name=('jit' if resource_env is None else 'pjit'))
api_name='jit')
tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
args = (*ctx.dim_var_values, *tokens_in, *args)

View File

@ -263,9 +263,9 @@ class NameStackTransformationTest(jtu.JaxTestCase):
return g(x)
hlo_text = _get_hlo(f)(2.)
self.assertIn('jvp(pjit(f))/pjit(g)/sin', hlo_text)
self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text)
self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text)
self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text)
self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text)
self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text)
def test_remat_appears_in_hlo(self):
@ad_checkpoint.remat

View File

@ -2076,7 +2076,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with global_mesh:
with self.assertRaisesRegex(
ValueError, "Received incompatible devices for pjitted computation"):
ValueError, "Received incompatible devices for jitted computation"):
pjit(lambda x: x)(input_array)
def test_array_lower_compile(self):
@ -2177,7 +2177,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with m1:
with self.assertRaisesRegex(
ValueError, "Received incompatible devices for pjitted computation"):
ValueError, "Received incompatible devices for jitted computation"):
pjit(lambda x, y: (x, y),
out_shardings=(NamedSharding(m1, spec),
NamedSharding(m2, spec)))(a1, a1)
@ -2192,7 +2192,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with m1:
with self.assertRaisesRegex(
ValueError, "Received incompatible devices for pjitted computation"):
ValueError, "Received incompatible devices for jitted computation"):
pjit(
lambda x, y: (x, y),
in_shardings=NamedSharding(m2, spec),
@ -2348,7 +2348,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
arr = jnp.array([1, 2, 3])
with self.assertRaisesRegex(
RuntimeError,
r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
r'jit requires a non-empty mesh if you are passing `PartitionSpec`s or'
r' `None` to in_shardings.*'):
pjit(lambda x: x, in_shardings=P('x'))(arr)
@ -2396,7 +2396,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.create_mesh((2, 2), ('x', 'y')):
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for pjitted computation"):
"Received incompatible devices for jitted computation"):
pjit(lambda x, y: (x, y))(uarr, carr)
def test_pjit_uncommitted_array_multi_devices(self):
@ -2418,7 +2418,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for pjitted computation. Got argument "
"Received incompatible devices for jitted computation. Got argument "
r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
pjit(lambda x, y: (x, y))(a, b)
@ -2430,7 +2430,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for pjitted computation. Got argument "
"Received incompatible devices for jitted computation. Got argument "
r"x\[0\] of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
r"argument x\[1\] of.*\<lambda\> with shape int.*\[3\] and device ids "
r"\[1\].*"):
@ -2443,7 +2443,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
c = jax.device_put(np.arange(16).reshape(8, 2),
NamedSharding(mesh, P('x', 'y')))
msg = ("Received incompatible devices for pjitted computation. Got "
msg = ("Received incompatible devices for jitted computation. Got "
r"argument {} of.*<lambda> with shape int.*\[3\] and device ids "
r"\[0\].*and argument {} of.*<lambda> with shape int.*\[8,2\] and "
r"device ids.*")
@ -2617,9 +2617,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
return f(inp1, inp2, inp3)
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for pjitted computation. Got argument "
"Received incompatible devices for jitted computation. Got argument "
r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*"
r"pjit inside pjit with device ids.*"):
r"pjit inside jit with device ids.*"):
my_nested_pjit(committed_inp, committed_inp, committed_inp)
@jtu.ignore_warning(category=DeprecationWarning,
@ -7236,7 +7236,7 @@ class PJitErrorTest(jtu.JaxTestCase):
xshape = (2, 5, 6)
x = jnp.arange(math.prod(xshape)).reshape(xshape)
with self.assertRaisesRegex(
ValueError, "Received incompatible devices for pjitted computation.*"):
ValueError, "Received incompatible devices for jitted computation.*"):
f(x)
@parameterized.named_parameters(