Add arg_names to the aval mismatch error raised during AOT compilation so it produces better error messages
PiperOrigin-RevId: 525561905
parent 968dbaf8f3
commit 53e6382f4a
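To illustrate the effect of the change, here is a usage sketch. The function f, the shapes, and the printed text are illustrative only; the message wording follows the updated test expectations in this commit and may differ slightly across JAX versions. Calling an AOT-compiled executable with arguments whose avals do not match the compiled signature now raises a TypeError that names each mismatched argument.

import jax
import jax.numpy as jnp

def f(x):
  return x + 1.0

x_f32 = jnp.arange(4, dtype=jnp.float32)
x_i32 = jnp.arange(4, dtype=jnp.int32)

# Ahead-of-time path: lower and compile for float32 inputs.
f_exe = jax.jit(f).lower(x_f32).compile()

try:
  f_exe(x_i32)  # wrong dtype for the compiled executable
except TypeError as e:
  # Expected to read roughly (per the updated tests below):
  #   Computation was compiled for different input types and called with
  #   different types. Here are the 1 mismatches:
  #   Compiled with float32[4] and called with int32[4] for arg x
  print(e)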
@@ -643,7 +643,7 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
   result_paths = trace_debug.result_paths()  # type: ignore
   debug_info = core.JaxprDebugInfo(
       trace_debug.traced_for, trace_debug.func_src_info,
-      trace_debug.arg_names, result_paths)
+      trace_debug.arg_names, tuple(result_paths))
   return jaxpr.replace(debug_info=debug_info)

 def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
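For orientation, the record constructed above carries the argument names that the new error path consumes. A minimal stand-in with the same fields is sketched here; the real container is core.JaxprDebugInfo in JAX internals, and the class name, example values, and exact field types below are assumptions for illustration only.

from typing import NamedTuple, Optional, Tuple

class JaxprDebugInfoSketch(NamedTuple):
  traced_for: str                        # e.g. "jit" or "pmap"
  func_src_info: str                     # e.g. "f at example.py:12"
  arg_names: Tuple[Optional[str], ...]   # e.g. ("x", "y"); consumed by the new error path
  result_paths: Tuple[str, ...]          # e.g. ("result",)

dbg = JaxprDebugInfoSketch("jit", "f at example.py:12", ("x", "y"), ("result",))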
@@ -898,7 +898,8 @@ def lower_parallel_callable(
       shards=shards, tuple_args=tuple_args,
       unordered_effects=unordered_effects,
       ordered_effects=ordered_effects,
-      keepalive=keepalive, host_callbacks=host_callbacks)
+      keepalive=keepalive, host_callbacks=host_callbacks,
+      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)


 class PmapComputation(stages.XlaLowering):
@@ -951,6 +952,34 @@ class UnloadedPmapExecutable:
   ordered_effects: List[core.Effect]
   keepalive: Sequence[Any]
   host_callbacks: Sequence[Any]
+  jaxpr_debug_info: core.JaxprDebugInfo
+
+  def build_execute_fun(self):
+    input_indices = []
+    for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
+      assert isinstance(spec, sharding_impls.PmapSharding), spec
+      assert isinstance(aval, core.ShapedArray), aval
+      input_indices.append(
+          sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
+          if spec.sharding_spec is not None else None)
+    handle_outs = local_avals_to_results_handler(self.local_output_avals,
+                                                 self.output_shardings)
+    handle_args = InputsHandler(self.compiled.local_devices(),
+                                self.input_shardings, input_indices)
+    execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
+                                    self.backend, handle_args, handle_outs,
+                                    self.unordered_effects,
+                                    self.ordered_effects, self.keepalive,
+                                    bool(self.host_callbacks),
+                                    set(range(len(input_indices))))
+    return execute_fun
+
+  def load(self) -> PmapExecutable:
+    fingerprint = getattr(self.compiled, "fingerprint", None)
+
+    return PmapExecutable(
+        self.compiled, self.build_execute_fun, fingerprint,
+        self.local_input_avals, self.jaxpr_debug_info, self)

   @staticmethod
   def from_hlo(xla_computation,
@@ -962,6 +991,7 @@
                ordered_effects: List[core.Effect],
                host_callbacks: List[Any],
                keepalive: Any,
+               jaxpr_debug_info: core.JaxprDebugInfo,
                compiler_options=None):
     devices = pci.devices
     if devices is None:
@@ -1056,7 +1086,7 @@
       return _compile_replicated_pmap_executable_from_hlo(
           xla_computation, pci, input_indices, in_shardings, handle_outs,
           compile_options, host_callbacks, bool(unordered_effects),
-          ordered_effects)
+          ordered_effects, jaxpr_debug_info)

     with dispatch.log_elapsed_time(
         f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
@@ -1075,38 +1105,13 @@
         ordered_effects=ordered_effects,
         keepalive=keepalive,
         host_callbacks=host_callbacks,
-    ).load()
-
-  def build_execute_fun(self):
-    input_indices = []
-    for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
-      assert isinstance(spec, sharding_impls.PmapSharding), spec
-      assert isinstance(aval, core.ShapedArray), aval
-      input_indices.append(
-          sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
-          if spec.sharding_spec is not None else None)
-    handle_outs = local_avals_to_results_handler(self.local_output_avals,
-                                                 self.output_shardings)
-    handle_args = InputsHandler(self.compiled.local_devices(),
-                                self.input_shardings, input_indices)
-    execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
-                                    self.backend, handle_args, handle_outs,
-                                    self.unordered_effects,
-                                    self.ordered_effects, self.keepalive,
-                                    bool(self.host_callbacks),
-                                    set(range(len(input_indices))))
-    return execute_fun
-
-  def load(self) -> PmapExecutable:
-    fingerprint = getattr(self.compiled, "fingerprint", None)
-
-    return PmapExecutable(self.compiled, self.build_execute_fun, fingerprint,
-                          self.local_input_avals, self)
+        jaxpr_debug_info=jaxpr_debug_info).load()


 def _compile_replicated_pmap_executable_from_hlo(
     xla_computation, pci, input_indices, in_shardings, handle_outs,
-    compile_options, host_callbacks, has_unordered_effects, ordered_effects):
+    compile_options, host_callbacks, has_unordered_effects, ordered_effects,
+    jaxpr_debug_info):
   # Use the standard out_handler.
   execute_fun = pci.backend.compile_replicated(
       is_trivial=False, name=pci.name, computation=xla_computation,
@@ -1116,20 +1121,23 @@ def _compile_replicated_pmap_executable_from_hlo(
       in_indices=input_indices, in_shardings=in_shardings,
       kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
   # TODO(frostig): need `compile_replicated` to give us the XLA executable
-  return PmapExecutable(None, lambda: execute_fun, None, pci.avals, None)
+  return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
+                        jaxpr_debug_info, None)


 class PmapExecutable(stages.XlaExecutable):
   __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
-               "fingerprint", "in_avals", "_unloaded_executable"]
+               "fingerprint", "in_avals", "_jaxpr_debug_info",
+               "_unloaded_executable"]

   def __init__(self, xla_executable, build_unsafe_call, fingerprint,
-               in_avals, unloaded_executable):
+               in_avals, jaxpr_debug_info, unloaded_executable):
     self.xla_executable = xla_executable
     self._unsafe_call = None
     self.build_unsafe_call = build_unsafe_call
     self.fingerprint = fingerprint
     self.in_avals = in_avals
+    self._jaxpr_debug_info = jaxpr_debug_info
     self._unloaded_executable = unloaded_executable

   @property
@@ -1147,7 +1155,7 @@ class PmapExecutable(stages.XlaExecutable):
   def call(self, *args):
     # TODO(frostig): do we need to check sharding and sharded avals?
     arg_avals = map(xla.abstractify, args)
-    check_arg_avals_for_call(self.in_avals, arg_avals)
+    check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

@@ -2655,7 +2663,7 @@ class UnloadedMeshExecutable:
           semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
           compile_options, tuple(host_callbacks), bool(unordered_effects),
           tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
-          pmap_nreps)
+          pmap_nreps, jaxpr_debug_info)

     if auto_spmd_lowering:
       assert mesh is not None
@@ -2794,7 +2802,7 @@ class MeshExecutable(stages.XlaExecutable):
     kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
     arg_avals = map(xla.abstractify, kept_args)
     ref_avals = self.in_avals
-    check_arg_avals_for_call(ref_avals, arg_avals)
+    check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
     # Check the GDA sharding and the input sharding.
     check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
                                           self._jaxpr_debug_info)
@@ -2832,18 +2840,28 @@
     return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [])  # type: ignore


-def check_arg_avals_for_call(ref_avals, arg_avals):
+def check_arg_avals_for_call(ref_avals, arg_avals,
+                             jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None):
   if len(ref_avals) != len(arg_avals):
     raise TypeError(
         f"Computation compiled for {len(ref_avals)} inputs "
         f"but called with {len(arg_avals)}")
-  for ref_aval, arg_aval in zip(ref_avals, arg_avals):
+  arg_names = ([''] * len(ref_avals) if jaxpr_debug_info is None else
+               jaxpr_debug_info.arg_names)
+  errors = []
+  num_errors = 5
+  for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
     if not core.typematch(ref_aval, arg_aval):
-      raise TypeError(
-          "Computation was compiled for different input types and called with "
-          "different types. One of the mismatches is:\n"
-          f"Compiled with:\n {ref_aval}\n"
-          f"called with:\n {arg_aval}")
+      errors.append(f"Compiled with {ref_aval} and called with {arg_aval} for "
+                    f"arg {name}")
+  if errors:
+    str_errors = '\n'.join(errors[:num_errors])
+    num_mismatch_str = (
+        f'the {len(errors)} mismatches' if len(errors) < num_errors else
+        f"{num_errors} mismatches out of {len(errors)}")
+    raise TypeError(
+        "Computation was compiled for different input types and called with "
+        f"different types. Here are {num_mismatch_str}:\n{str_errors}")


 def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
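The reporting logic added above can be read in isolation. Below is a self-contained sketch that mirrors it: plain strings stand in for JAX avals, simple equality stands in for core.typematch, and check_mismatches is a hypothetical name used only for this example.

from typing import Optional, Sequence

def check_mismatches(ref_avals: Sequence[str], arg_avals: Sequence[str],
                     arg_names: Optional[Sequence[str]] = None) -> None:
  if arg_names is None:
    arg_names = [''] * len(ref_avals)
  errors = []
  num_errors = 5  # report at most five mismatches, as in the diff
  for ref_aval, arg_aval, name in zip(ref_avals, arg_avals, arg_names):
    if ref_aval != arg_aval:  # stand-in for core.typematch
      errors.append(f"Compiled with {ref_aval} and called with {arg_aval} for "
                    f"arg {name}")
  if errors:
    str_errors = '\n'.join(errors[:num_errors])
    num_mismatch_str = (
        f"the {len(errors)} mismatches" if len(errors) < num_errors else
        f"{num_errors} mismatches out of {len(errors)}")
    raise TypeError(
        "Computation was compiled for different input types and called with "
        f"different types. Here are {num_mismatch_str}:\n{str_errors}")

# check_mismatches(["float32[4]", "float32[4]"], ["int32[4]", "int32[4]"], ["x", "y"])
# -> TypeError: ... Here are the 2 mismatches:
#    Compiled with float32[4] and called with int32[4] for arg x
#    Compiled with float32[4] and called with int32[4] for arg y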
@@ -2905,7 +2923,7 @@ def _compile_replicated_mesh_executable_from_hlo(
     computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
     semantics_out_shardings, auto_spmd_lowering, compile_options,
     host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
-    backend, da, committed, pmap_nreps):
+    backend, da, committed, pmap_nreps, jaxpr_debug_info):
   assert not auto_spmd_lowering
   assert isinstance(da, _DeviceAssignment)
   in_shardings = semantics_in_shardings.shardings
@@ -2930,7 +2948,7 @@
     xla_executable = None
     return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                           in_shardings, out_shardings, auto_spmd_lowering,
-                          kept_var_idx, None)
+                          kept_var_idx, jaxpr_debug_info, None)


 def _compile_replicated_mesh_executable_from_trivial_jaxpr(
@@ -2998,12 +3016,13 @@ def check_gda_or_array_xla_sharding_match(
                     f"arg {name} with shape: {arg.aval.str_short()}")

   if errors:
-    str_errors = '\n'.join(errors)
-    num_mismatch_str = (f'{len(errors)} mismatches' if len(errors) < num_errors else
-                        f"{num_errors} mismatches out of {len(errors)}")
+    str_errors = '\n'.join(errors[:num_errors])
+    num_mismatch_str = (
+        f'the {len(errors)} mismatches' if len(errors) < num_errors else
+        f"{num_errors} mismatches out of {len(errors)}")
     raise ValueError(
         "Array(s) sharding does not match the input(s) sharding. "
-        f"Here are the {num_mismatch_str}:\n{str_errors}")
+        f"Here are {num_mismatch_str}:\n{str_errors}")


 def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
@@ -992,10 +992,9 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     f_exe = self.jit(f).lower(x_f32).compile()
     self.assertRaisesRegex(
         TypeError,
-        "Computation was compiled for different input types and called with "
-        "different types. One of the mismatches is:\n"
-        "Compiled with:\n.*float32.*\n"
-        "called with:\n.*int32.*",
+        r"Computation was compiled for different input types and called with "
+        r"different types. Here are the 1 mismatches:\n"
+        r"Compiled with.*float32.*and called with.*int32.*for arg x",
         lambda: f_exe(x_i32))

   def test_jit_lower_compile_multi_arg(self):
@@ -930,13 +930,13 @@ class PJitTest(jtu.BufferDonationTestCase):
     x_f32 = x.astype(jnp.float32)
     x_i32 = x.astype(jnp.int32)
     exe = f.lower(x_f32, x_f32).compile()
-    self.assertRaisesRegex(
+    with self.assertRaisesRegex(
         TypeError,
-        "Computation was compiled for different input types and called with "
-        "different types. One of the mismatches is:\n"
-        "Compiled with:\n.*float32.*\n"
-        "called with:\n.*int32.*",
-        lambda: exe(x_i32, x_i32))
+        r"Computation was compiled for different input types and called with "
+        r"different types. Here are the 2 mismatches:\n"
+        r"Compiled with.*float32.*and called with.*int32.*for arg x\n"
+        r"Compiled with.*float32.*and called with.*int32.*for arg y"):
+      exe(x_i32, x_i32)

   @jtu.with_mesh([('x', 2), ('y', 2)])
   def testLowerAsText(self):
@@ -1541,7 +1541,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         r"Array\(s\) sharding does not match the input\(s\) sharding. "
-        "Here are the 5 mismatches out of 6"):
+        "Here are 5 mismatches out of 6"):
       compiled(a2, a2, a2, a2, a2, a2)

     with global_mesh:
@@ -240,10 +240,9 @@ class PythonPmapTest(jtu.JaxTestCase):
     f_exe = f.lower(x_f32).compile()
     self.assertRaisesRegex(
         TypeError,
-        "Computation was compiled for different input types and called with "
-        "different types. One of the mismatches is:\n"
-        "Compiled with:\n.*float32.*\n"
-        "called with:\n.*int32.*",
+        r"Computation was compiled for different input types and called with "
+        r"different types. Here are the 1 mismatches:\n"
+        r"Compiled with.*float32.*and called with.*int32.*for arg x",
         lambda: f_exe(x_i32))

   def testLowerCompileMultiArg(self):
@@ -707,11 +707,10 @@ class XMapTest(XMapTestCase):
     f_exe = f.lower(x_f32).compile()
     self.assertRaisesRegex(
         TypeError,
-        "Computation was compiled for different input types and called with "
-        "different types. One of the mismatches is:\n"
-        "Compiled with:\n.*float32.*\n"
-        "called with:\n.*int32.*",
-        lambda: f_exe(x_i32))
+        r"Computation was compiled for different input types and called with "
+        r"different types. Here are the 1 mismatches:\n"
+        r"Compiled with.*float32.*and called with.*int32.*",
+        lambda: f_exe(x_i32))

   def testLowerAsText(self):
     f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])