diff --git a/docs/aot.md b/docs/aot.md index 9ddd1d4bd..1a7ec0080 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -97,18 +97,16 @@ lowering raises an error: >>> x_1d = y_1d = jnp.arange(3) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) ... -TypeError: Computation compiled for input types: - ShapedArray(int32[]), ShapedArray(int32[]) -called with: - ShapedArray(int32[3]), ShapedArray(int32[3]) +TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: +Argument 'x' compiled with int32[] and called with int32[3] +Argument 'y' compiled with int32[] and called with int32[3] ->>> x_f = y_f = 72.0 +>>> x_f = y_f = jnp.float32(72.) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) ... -TypeError: Computation compiled for input types: - ShapedArray(int32[]), ShapedArray(int32[]) -called with: - ShapedArray(float32[]), ShapedArray(float32[]) +TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: +Argument 'x' compiled with int32[] and called with float32[] +Argument 'y' compiled with int32[] and called with float32[] ``` Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 06988da83..5fec7eeb8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2851,22 +2851,29 @@ def check_arg_avals_for_call(ref_avals, arg_avals, raise TypeError( f"Computation compiled for {len(ref_avals)} inputs " f"but called with {len(arg_avals)}") - arg_names = ([''] * len(ref_avals) if jaxpr_debug_info is None else - jaxpr_debug_info.arg_names) + + if jaxpr_debug_info is not None: + arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names] + else: + num_args = len(ref_avals) + arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)] + errors = [] - num_errors = 5 for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names): if not core.typematch(ref_aval, arg_aval): - errors.append(f"Compiled with {ref_aval} and called with {arg_aval} for " - f"arg {name}") + errors.append( + f"Argument {name} compiled with {ref_aval.str_short()} and called " + f"with {arg_aval.str_short()}") if errors: - str_errors = '\n'.join(errors[:num_errors]) - num_mismatch_str = ( - f'the {len(errors)} mismatches' if len(errors) < num_errors else - f"{num_errors} mismatches out of {len(errors)}") + max_num_errors = 5 + str_errors = "\n".join(errors[:max_num_errors]) + if len(errors) >= max_num_errors: + num_mismatch_str = f"The first {max_num_errors} of {len(errors)}" + else: + num_mismatch_str = "The" raise TypeError( - "Computation was compiled for different input types and called with " - f"different types. Here are {num_mismatch_str}:\n{str_errors}") + "Argument types differ from the types for which this computation was " + f"compiled. {num_mismatch_str} mismatches are:\n{str_errors}") def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): diff --git a/tests/api_test.py b/tests/api_test.py index a87671033..2ea60e62b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -986,9 +986,9 @@ class CPPJitTest(jtu.BufferDonationTestCase): f_exe = self.jit(f).lower(x_f32).compile() self.assertRaisesRegex( TypeError, - r"Computation was compiled for different input types and called with " - r"different types. Here are the 1 mismatches:\n" - r"Compiled with.*float32.*and called with.*int32.*for arg x", + r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 'x' compiled with.*float32.*and called with.*int32.*", lambda: f_exe(x_i32)) def test_jit_lower_compile_multi_arg(self): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e40173099..51035e080 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -952,10 +952,10 @@ class PJitTest(jtu.BufferDonationTestCase): exe = f.lower(x_f32, x_f32).compile() with self.assertRaisesRegex( TypeError, - r"Computation was compiled for different input types and called with " - r"different types. Here are the 2 mismatches:\n" - r"Compiled with.*float32.*and called with.*int32.*for arg x\n" - r"Compiled with.*float32.*and called with.*int32.*for arg y"): + r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 'x' compiled with.*float32.*and called with.*int32.*\n" + r"Argument 'y' compiled with.*float32.*and called with.*int32.*"): exe(x_i32, x_i32) @jtu.with_mesh([('x', 2), ('y', 2)]) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index b1c515b12..33c3c72de 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -242,9 +242,9 @@ class PythonPmapTest(jtu.JaxTestCase): f_exe = f.lower(x_f32).compile() self.assertRaisesRegex( TypeError, - r"Computation was compiled for different input types and called with " - r"different types. Here are the 1 mismatches:\n" - r"Compiled with.*float32.*and called with.*int32.*for arg x", + r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 'x' compiled with.*float32.*and called with.*int32.*", lambda: f_exe(x_i32)) def testLowerCompileMultiArg(self): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index e53f04623..349c542a6 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -708,9 +708,9 @@ class XMapTest(XMapTestCase): f_exe = f.lower(x_f32).compile() self.assertRaisesRegex( TypeError, - r"Computation was compiled for different input types and called with " - r"different types. Here are the 1 mismatches:\n" - r"Compiled with.*float32.*and called with.*int32.*", + r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 1/1 compiled with.*float32.*and called with.*int32.*", lambda: f_exe(x_i32)) def testLowerAsText(self):