mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16677 from froystig:aot-docs
PiperOrigin-RevId: 547078408
This commit is contained in:
commit
392914bd46
20
docs/aot.md
20
docs/aot.md
@ -95,20 +95,18 @@ lowering raises an error:
|
||||
|
||||
```python
|
||||
>>> x_1d = y_1d = jnp.arange(3)
|
||||
>>> jax.jit(f)(i32_scalar, i32_scalar).compile(x_1d, y_1d)
|
||||
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
|
||||
...
|
||||
TypeError: Computation compiled for input types:
|
||||
ShapedArray(int32[]), ShapedArray(int32[])
|
||||
called with:
|
||||
ShapedArray(int32[3]), ShapedArray(int32[3])
|
||||
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
|
||||
Argument 'x' compiled with int32[] and called with int32[3]
|
||||
Argument 'y' compiled with int32[] and called with int32[3]
|
||||
|
||||
>>> x_f = y_f = 72.0
|
||||
>>> jax.jit(f)(i32_scalar, i32_scalar).compile(x_f, y_f)
|
||||
>>> x_f = y_f = jnp.float32(72.)
|
||||
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
|
||||
...
|
||||
TypeError: Computation compiled for input types:
|
||||
ShapedArray(int32[]), ShapedArray(int32[])
|
||||
called with:
|
||||
ShapedArray(float32[]), ShapedArray(float32[])
|
||||
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
|
||||
Argument 'x' compiled with int32[] and called with float32[]
|
||||
Argument 'y' compiled with int32[] and called with float32[]
|
||||
```
|
||||
|
||||
Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time
|
||||
|
@ -2851,22 +2851,29 @@ def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
raise TypeError(
|
||||
f"Computation compiled for {len(ref_avals)} inputs "
|
||||
f"but called with {len(arg_avals)}")
|
||||
arg_names = ([''] * len(ref_avals) if jaxpr_debug_info is None else
|
||||
jaxpr_debug_info.arg_names)
|
||||
|
||||
if jaxpr_debug_info is not None:
|
||||
arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
|
||||
else:
|
||||
num_args = len(ref_avals)
|
||||
arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
|
||||
|
||||
errors = []
|
||||
num_errors = 5
|
||||
for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
|
||||
if not core.typematch(ref_aval, arg_aval):
|
||||
errors.append(f"Compiled with {ref_aval} and called with {arg_aval} for "
|
||||
f"arg {name}")
|
||||
errors.append(
|
||||
f"Argument {name} compiled with {ref_aval.str_short()} and called "
|
||||
f"with {arg_aval.str_short()}")
|
||||
if errors:
|
||||
str_errors = '\n'.join(errors[:num_errors])
|
||||
num_mismatch_str = (
|
||||
f'the {len(errors)} mismatches' if len(errors) < num_errors else
|
||||
f"{num_errors} mismatches out of {len(errors)}")
|
||||
max_num_errors = 5
|
||||
str_errors = "\n".join(errors[:max_num_errors])
|
||||
if len(errors) >= max_num_errors:
|
||||
num_mismatch_str = f"The first {max_num_errors} of {len(errors)}"
|
||||
else:
|
||||
num_mismatch_str = "The"
|
||||
raise TypeError(
|
||||
"Computation was compiled for different input types and called with "
|
||||
f"different types. Here are {num_mismatch_str}:\n{str_errors}")
|
||||
"Argument types differ from the types for which this computation was "
|
||||
f"compiled. {num_mismatch_str} mismatches are:\n{str_errors}")
|
||||
|
||||
|
||||
def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
|
||||
|
@ -986,9 +986,9 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
f_exe = self.jit(f).lower(x_f32).compile()
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Computation was compiled for different input types and called with "
|
||||
r"different types. Here are the 1 mismatches:\n"
|
||||
r"Compiled with.*float32.*and called with.*int32.*for arg x",
|
||||
r"Argument types differ .*"
|
||||
r"The mismatches are:\n"
|
||||
r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
|
||||
lambda: f_exe(x_i32))
|
||||
|
||||
def test_jit_lower_compile_multi_arg(self):
|
||||
|
@ -952,10 +952,10 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
exe = f.lower(x_f32, x_f32).compile()
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Computation was compiled for different input types and called with "
|
||||
r"different types. Here are the 2 mismatches:\n"
|
||||
r"Compiled with.*float32.*and called with.*int32.*for arg x\n"
|
||||
r"Compiled with.*float32.*and called with.*int32.*for arg y"):
|
||||
r"Argument types differ .*"
|
||||
r"The mismatches are:\n"
|
||||
r"Argument 'x' compiled with.*float32.*and called with.*int32.*\n"
|
||||
r"Argument 'y' compiled with.*float32.*and called with.*int32.*"):
|
||||
exe(x_i32, x_i32)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
|
@ -242,9 +242,9 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f_exe = f.lower(x_f32).compile()
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Computation was compiled for different input types and called with "
|
||||
r"different types. Here are the 1 mismatches:\n"
|
||||
r"Compiled with.*float32.*and called with.*int32.*for arg x",
|
||||
r"Argument types differ .*"
|
||||
r"The mismatches are:\n"
|
||||
r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
|
||||
lambda: f_exe(x_i32))
|
||||
|
||||
def testLowerCompileMultiArg(self):
|
||||
|
@ -708,9 +708,9 @@ class XMapTest(XMapTestCase):
|
||||
f_exe = f.lower(x_f32).compile()
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Computation was compiled for different input types and called with "
|
||||
r"different types. Here are the 1 mismatches:\n"
|
||||
r"Compiled with.*float32.*and called with.*int32.*",
|
||||
r"Argument types differ .*"
|
||||
r"The mismatches are:\n"
|
||||
r"Argument 1/1 compiled with.*float32.*and called with.*int32.*",
|
||||
lambda: f_exe(x_i32))
|
||||
|
||||
def testLowerAsText(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user