mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Just use jit
as the string in error messages instead of jit
and pjit
based on resource_env. This is to start deprecating the need for with mesh
and replace it with use_mesh(mesh)
.
PiperOrigin-RevId: 733959962
This commit is contained in:
parent
ba5349f896
commit
a67ab9fade
@ -199,10 +199,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
|
||||
profiler = None
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
|
||||
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
|
||||
msg = _device_assignment_mismatch_error(
|
||||
fun_name, fails, args_flat, api_name, p.arg_names)
|
||||
fun_name, fails, args_flat, 'jit', p.arg_names)
|
||||
raise ValueError(msg) from None
|
||||
except xla.InvalidInputException as e:
|
||||
arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names
|
||||
@ -591,13 +590,12 @@ def _infer_params_impl(
|
||||
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
|
||||
in_shardings_treedef = out_shardings_treedef = treedef
|
||||
else:
|
||||
jit_name = 'pjit' if pjit_mesh is not None else 'jit'
|
||||
in_shardings_leaves = tuple(
|
||||
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
|
||||
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit')
|
||||
for x in ji.in_shardings_leaves)
|
||||
in_shardings_treedef = ji.in_shardings_treedef
|
||||
out_shardings_leaves = tuple(
|
||||
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name)
|
||||
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit')
|
||||
for x in ji.out_shardings_leaves)
|
||||
out_shardings_treedef = ji.out_shardings_treedef
|
||||
|
||||
@ -1760,12 +1758,10 @@ def _pjit_lower(
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
util.test_event("pjit_lower")
|
||||
if resource_env is not None:
|
||||
mesh, api_name = resource_env.physical_mesh, 'pjit'
|
||||
else:
|
||||
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
|
||||
mesh = (resource_env.physical_mesh if resource_env is not None else
|
||||
mesh_lib.get_concrete_mesh())
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, api_name, name, in_shardings, out_shardings,
|
||||
jaxpr, 'jit', name, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, tuple(donated_invars),
|
||||
keep_unused=keep_unused, context_mesh=mesh,
|
||||
compiler_options_kvs=compiler_options_kvs,
|
||||
@ -1929,7 +1925,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
|
||||
func = _pjit_cached_lower_jaxpr_to_fun(
|
||||
ctx, name, jaxpr, tuple(effects), in_shardings,
|
||||
out_shardings, in_layouts, out_layouts,
|
||||
api_name=('jit' if resource_env is None else 'pjit'))
|
||||
api_name='jit')
|
||||
|
||||
tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
|
||||
args = (*ctx.dim_var_values, *tokens_in, *args)
|
||||
|
@ -263,9 +263,9 @@ class NameStackTransformationTest(jtu.JaxTestCase):
|
||||
return g(x)
|
||||
|
||||
hlo_text = _get_hlo(f)(2.)
|
||||
self.assertIn('jvp(pjit(f))/pjit(g)/sin', hlo_text)
|
||||
self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text)
|
||||
self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text)
|
||||
self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text)
|
||||
self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text)
|
||||
self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text)
|
||||
|
||||
def test_remat_appears_in_hlo(self):
|
||||
@ad_checkpoint.remat
|
||||
|
@ -2076,7 +2076,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with global_mesh:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Received incompatible devices for pjitted computation"):
|
||||
ValueError, "Received incompatible devices for jitted computation"):
|
||||
pjit(lambda x: x)(input_array)
|
||||
|
||||
def test_array_lower_compile(self):
|
||||
@ -2177,7 +2177,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with m1:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Received incompatible devices for pjitted computation"):
|
||||
ValueError, "Received incompatible devices for jitted computation"):
|
||||
pjit(lambda x, y: (x, y),
|
||||
out_shardings=(NamedSharding(m1, spec),
|
||||
NamedSharding(m2, spec)))(a1, a1)
|
||||
@ -2192,7 +2192,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with m1:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Received incompatible devices for pjitted computation"):
|
||||
ValueError, "Received incompatible devices for jitted computation"):
|
||||
pjit(
|
||||
lambda x, y: (x, y),
|
||||
in_shardings=NamedSharding(m2, spec),
|
||||
@ -2348,7 +2348,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
arr = jnp.array([1, 2, 3])
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
|
||||
r'jit requires a non-empty mesh if you are passing `PartitionSpec`s or'
|
||||
r' `None` to in_shardings.*'):
|
||||
pjit(lambda x: x, in_shardings=P('x'))(arr)
|
||||
|
||||
@ -2396,7 +2396,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.create_mesh((2, 2), ('x', 'y')):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for pjitted computation"):
|
||||
"Received incompatible devices for jitted computation"):
|
||||
pjit(lambda x, y: (x, y))(uarr, carr)
|
||||
|
||||
def test_pjit_uncommitted_array_multi_devices(self):
|
||||
@ -2418,7 +2418,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for pjitted computation. Got argument "
|
||||
"Received incompatible devices for jitted computation. Got argument "
|
||||
r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
|
||||
r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
|
||||
pjit(lambda x, y: (x, y))(a, b)
|
||||
@ -2430,7 +2430,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for pjitted computation. Got argument "
|
||||
"Received incompatible devices for jitted computation. Got argument "
|
||||
r"x\[0\] of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
|
||||
r"argument x\[1\] of.*\<lambda\> with shape int.*\[3\] and device ids "
|
||||
r"\[1\].*"):
|
||||
@ -2443,7 +2443,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
c = jax.device_put(np.arange(16).reshape(8, 2),
|
||||
NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
msg = ("Received incompatible devices for pjitted computation. Got "
|
||||
msg = ("Received incompatible devices for jitted computation. Got "
|
||||
r"argument {} of.*<lambda> with shape int.*\[3\] and device ids "
|
||||
r"\[0\].*and argument {} of.*<lambda> with shape int.*\[8,2\] and "
|
||||
r"device ids.*")
|
||||
@ -2617,9 +2617,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
return f(inp1, inp2, inp3)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for pjitted computation. Got argument "
|
||||
"Received incompatible devices for jitted computation. Got argument "
|
||||
r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*"
|
||||
r"pjit inside pjit with device ids.*"):
|
||||
r"pjit inside jit with device ids.*"):
|
||||
my_nested_pjit(committed_inp, committed_inp, committed_inp)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
@ -7236,7 +7236,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
xshape = (2, 5, 6)
|
||||
x = jnp.arange(math.prod(xshape)).reshape(xshape)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Received incompatible devices for pjitted computation.*"):
|
||||
ValueError, "Received incompatible devices for jitted computation.*"):
|
||||
f(x)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
Loading…
x
Reference in New Issue
Block a user