Merge pull request #16677 from froystig:aot-docs

PiperOrigin-RevId: 547078408
This commit is contained in:
jax authors 2023-07-10 22:38:39 -07:00
commit 392914bd46
6 changed files with 40 additions and 35 deletions

View File

@ -95,20 +95,18 @@ lowering raises an error:
```python
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f)(i32_scalar, i32_scalar).compile(x_1d, y_1d)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
...
TypeError: Computation compiled for input types:
ShapedArray(int32[]), ShapedArray(int32[])
called with:
ShapedArray(int32[3]), ShapedArray(int32[3])
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]
>>> x_f = y_f = 72.0
>>> jax.jit(f)(i32_scalar, i32_scalar).compile(x_f, y_f)
>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
...
TypeError: Computation compiled for input types:
ShapedArray(int32[]), ShapedArray(int32[])
called with:
ShapedArray(float32[]), ShapedArray(float32[])
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]
```
Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time

View File

@ -2851,22 +2851,29 @@ def check_arg_avals_for_call(ref_avals, arg_avals,
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
f"but called with {len(arg_avals)}")
arg_names = ([''] * len(ref_avals) if jaxpr_debug_info is None else
jaxpr_debug_info.arg_names)
if jaxpr_debug_info is not None:
arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
else:
num_args = len(ref_avals)
arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
errors = []
num_errors = 5
for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
if not core.typematch(ref_aval, arg_aval):
errors.append(f"Compiled with {ref_aval} and called with {arg_aval} for "
f"arg {name}")
errors.append(
f"Argument {name} compiled with {ref_aval.str_short()} and called "
f"with {arg_aval.str_short()}")
if errors:
str_errors = '\n'.join(errors[:num_errors])
num_mismatch_str = (
f'the {len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
max_num_errors = 5
str_errors = "\n".join(errors[:max_num_errors])
if len(errors) >= max_num_errors:
num_mismatch_str = f"The first {max_num_errors} of {len(errors)}"
else:
num_mismatch_str = "The"
raise TypeError(
"Computation was compiled for different input types and called with "
f"different types. Here are {num_mismatch_str}:\n{str_errors}")
"Argument types differ from the types for which this computation was "
f"compiled. {num_mismatch_str} mismatches are:\n{str_errors}")
def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):

View File

@ -986,9 +986,9 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f_exe = self.jit(f).lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
r"Computation was compiled for different input types and called with "
r"different types. Here are the 1 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*for arg x",
r"Argument types differ .*"
r"The mismatches are:\n"
r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
lambda: f_exe(x_i32))
def test_jit_lower_compile_multi_arg(self):

View File

@ -952,10 +952,10 @@ class PJitTest(jtu.BufferDonationTestCase):
exe = f.lower(x_f32, x_f32).compile()
with self.assertRaisesRegex(
TypeError,
r"Computation was compiled for different input types and called with "
r"different types. Here are the 2 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*for arg x\n"
r"Compiled with.*float32.*and called with.*int32.*for arg y"):
r"Argument types differ .*"
r"The mismatches are:\n"
r"Argument 'x' compiled with.*float32.*and called with.*int32.*\n"
r"Argument 'y' compiled with.*float32.*and called with.*int32.*"):
exe(x_i32, x_i32)
@jtu.with_mesh([('x', 2), ('y', 2)])

View File

@ -242,9 +242,9 @@ class PythonPmapTest(jtu.JaxTestCase):
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
r"Computation was compiled for different input types and called with "
r"different types. Here are the 1 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*for arg x",
r"Argument types differ .*"
r"The mismatches are:\n"
r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
lambda: f_exe(x_i32))
def testLowerCompileMultiArg(self):

View File

@ -708,9 +708,9 @@ class XMapTest(XMapTestCase):
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
r"Computation was compiled for different input types and called with "
r"different types. Here are the 1 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*",
r"Argument types differ .*"
r"The mismatches are:\n"
r"Argument 1/1 compiled with.*float32.*and called with.*int32.*",
lambda: f_exe(x_i32))
def testLowerAsText(self):