Plumb debug_info to meshExecutable as a optional arg to raise better error messages.

PiperOrigin-RevId: 525521694
This commit is contained in:
Yash Katariya 2023-04-19 12:35:15 -07:00 committed by jax authors
parent a2fbd59e63
commit 0a19638490
4 changed files with 60 additions and 23 deletions

View File

@ -2110,7 +2110,8 @@ def lower_sharding_computation(
backend=backend,
device_assignment=da_object,
committed=committed,
pmap_nreps=nreps)
pmap_nreps=nreps,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
def _to_logical_sharding(
@ -2291,7 +2292,8 @@ def lower_mesh_computation(
kept_var_idx=set(range(len(global_in_avals))),
backend=backend,
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True)
committed=True,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
class MeshComputation(stages.XlaLowering):
_hlo: Optional[ir.Module]
@ -2580,6 +2582,7 @@ class UnloadedMeshExecutable:
host_callbacks: Sequence[Any]
kept_var_idx: Set[int]
auto_spmd_lowering: bool
jaxpr_debug_info: Optional[core.JaxprDebugInfo]
def build_unsafe_call(self):
input_indices = _get_input_indices(self.input_avals, self.input_shardings,
@ -2601,7 +2604,7 @@ class UnloadedMeshExecutable:
self.input_avals,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self)
self.jaxpr_debug_info, self)
# May return a MeshExecutable in the compile_replicated case.
@staticmethod
@ -2627,6 +2630,7 @@ class UnloadedMeshExecutable:
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]],
committed: bool,
pmap_nreps: int = 1,
jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None,
compiler_options=None
) -> MeshExecutable:
compiler_options_keys = tuple(
@ -2714,7 +2718,8 @@ class UnloadedMeshExecutable:
keepalive=keepalive,
host_callbacks=host_callbacks,
kept_var_idx=kept_var_idx,
auto_spmd_lowering=auto_spmd_lowering).load()
auto_spmd_lowering=auto_spmd_lowering,
jaxpr_debug_info=jaxpr_debug_info).load()
class MeshExecutableFastpathData(NamedTuple):
@ -2731,12 +2736,12 @@ class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_unloaded_executable",
"_jaxpr_debug_info", "_unloaded_executable",
]
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
unloaded_executable=None):
jaxpr_debug_info=None, unloaded_executable=None):
self.xla_executable = xla_executable
self.build_unsafe_call = build_unsafe_call
# in_avals is a list of global and local avals. Aval is global if input
@ -2747,6 +2752,7 @@ class MeshExecutable(stages.XlaExecutable):
self._out_shardings = out_shardings
self._auto_spmd_lowering = auto_spmd_lowering
self._kept_var_idx = kept_var_idx
self._jaxpr_debug_info = jaxpr_debug_info
self._unloaded_executable = unloaded_executable
@property
@ -2790,7 +2796,8 @@ class MeshExecutable(stages.XlaExecutable):
ref_avals = self.in_avals
check_arg_avals_for_call(ref_avals, arg_avals)
# Check the GDA sharding and the input sharding.
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings)
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
self._jaxpr_debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
@ -2968,10 +2975,14 @@ def check_device_backend_on_shardings(shardings) -> bool:
def check_gda_or_array_xla_sharding_match(
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> None:
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
jaxpr_debug_info: Optional[core.JaxprDebugInfo]) -> None:
from jax._src.array import ArrayImpl
for arg, xs in safe_zip(args, in_xla_shardings):
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
jaxpr_debug_info.arg_names)
errors = []
num_errors = 5
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
if not isinstance(arg, ArrayImpl):
continue
@ -2982,10 +2993,17 @@ def check_gda_or_array_xla_sharding_match(
not op_shardings.are_op_shardings_equal(
arg.sharding._to_xla_op_sharding(arg.ndim),
xs._to_xla_op_sharding(arg.ndim))):
raise ValueError(
f"Array sharding does not match the input sharding. "
f"Got Array sharding: {arg.sharding} and xla sharding: {xs} for "
f"arg shape: {arg.shape}, arg value: {arg}")
errors.append(
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
f"arg {name} with shape: {arg.aval.str_short()}")
if errors:
str_errors = '\n'.join(errors)
num_mismatch_str = (f'{len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
raise ValueError(
"Array(s) sharding does not match the input(s) sharding. "
f"Here are the {num_mismatch_str}:\n{str_errors}")
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:

View File

@ -1076,7 +1076,8 @@ def _pjit_call_impl(*args, jaxpr,
_most_recent_pjit_call_executable.value = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.jax_enable_checks:
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings)
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
jaxpr.jaxpr.debug_info)
if config.jax_distributed_debug:
# Defensively only perform fingerprint logic if debug logging is enabled
# NOTE(skyewm): I didn't benchmark this

View File

@ -329,11 +329,13 @@ def _check_lowered(lowered) -> None:
# Check that we do not see new compile_args. When we add a compile_args it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = ["backend", "mesh", "global_in_avals",
allowed_compile_args = [
"backend", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment"]
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info"]
for compile_arg in lowered.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")

View File

@ -1229,7 +1229,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
input_data)
with self.assertRaisesRegex(
ValueError,
"Array sharding does not match the input sharding."):
r"Array\(s\) sharding does not match the input\(s\) "
r"sharding.*\n.*for arg x"):
compiled(arr)
def test_gda_auto_shardings_len(self):
@ -1530,16 +1531,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
with global_mesh:
f = pjit(
lambda x, y: x @ y.T,
lambda x, y, z, a, b, c: (x @ y.T, y, z, a, b, c),
in_shardings=NamedSharding(global_mesh, P('x' ,'y')))
compiled = f.lower(aval, aval).compile()
out = compiled(a1, a1)
compiled = f.lower(aval, aval, aval, aval, aval, aval).compile()
out, *_ = compiled(a1, a1, a1, a1, a1, a1)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out._value, input_data @ input_data.T)
with self.assertRaisesRegex(
ValueError, 'Array sharding does not match the input sharding'):
compiled(a2, a2)
ValueError,
r"Array\(s\) sharding does not match the input\(s\) sharding. "
"Here are the 5 mismatches out of 6"):
compiled(a2, a2, a2, a2, a2, a2)
with global_mesh:
f = pjit(lambda a: a, in_shardings=NamedSharding(global_mesh, P('x' ,'y')))
abstract_inp = {'x': aval, 'y': {'y1': aval}}
inp1 = {'x': a1, 'y': {'y1': a1}}
compiled = f.lower(abstract_inp).compile()
compiled(inp1)
inp2 = {'x': a2, 'y': {'y1': a2}}
with self.assertRaisesRegex(
ValueError,
r"Array\(s\) sharding does not match the input\(s\) sharding. "
"Here are the 2 mismatches"):
compiled(inp2)
def test_globally_sharded_key_array_result_8x4_single_device(self):
input_shape = (8, 4)