mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Plumb debug_info to meshExecutable as a optional arg to raise better error messages.
PiperOrigin-RevId: 525521694
This commit is contained in:
parent
a2fbd59e63
commit
0a19638490
@ -2110,7 +2110,8 @@ def lower_sharding_computation(
|
||||
backend=backend,
|
||||
device_assignment=da_object,
|
||||
committed=committed,
|
||||
pmap_nreps=nreps)
|
||||
pmap_nreps=nreps,
|
||||
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
|
||||
|
||||
|
||||
def _to_logical_sharding(
|
||||
@ -2291,7 +2292,8 @@ def lower_mesh_computation(
|
||||
kept_var_idx=set(range(len(global_in_avals))),
|
||||
backend=backend,
|
||||
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
|
||||
committed=True)
|
||||
committed=True,
|
||||
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
|
||||
|
||||
class MeshComputation(stages.XlaLowering):
|
||||
_hlo: Optional[ir.Module]
|
||||
@ -2580,6 +2582,7 @@ class UnloadedMeshExecutable:
|
||||
host_callbacks: Sequence[Any]
|
||||
kept_var_idx: Set[int]
|
||||
auto_spmd_lowering: bool
|
||||
jaxpr_debug_info: Optional[core.JaxprDebugInfo]
|
||||
|
||||
def build_unsafe_call(self):
|
||||
input_indices = _get_input_indices(self.input_avals, self.input_shardings,
|
||||
@ -2601,7 +2604,7 @@ class UnloadedMeshExecutable:
|
||||
self.input_avals,
|
||||
self.input_shardings, self.output_shardings,
|
||||
self.auto_spmd_lowering, self.kept_var_idx,
|
||||
self)
|
||||
self.jaxpr_debug_info, self)
|
||||
|
||||
# May return a MeshExecutable in the compile_replicated case.
|
||||
@staticmethod
|
||||
@ -2627,6 +2630,7 @@ class UnloadedMeshExecutable:
|
||||
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]],
|
||||
committed: bool,
|
||||
pmap_nreps: int = 1,
|
||||
jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None,
|
||||
compiler_options=None
|
||||
) -> MeshExecutable:
|
||||
compiler_options_keys = tuple(
|
||||
@ -2714,7 +2718,8 @@ class UnloadedMeshExecutable:
|
||||
keepalive=keepalive,
|
||||
host_callbacks=host_callbacks,
|
||||
kept_var_idx=kept_var_idx,
|
||||
auto_spmd_lowering=auto_spmd_lowering).load()
|
||||
auto_spmd_lowering=auto_spmd_lowering,
|
||||
jaxpr_debug_info=jaxpr_debug_info).load()
|
||||
|
||||
|
||||
class MeshExecutableFastpathData(NamedTuple):
|
||||
@ -2731,12 +2736,12 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
__slots__ = [
|
||||
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
|
||||
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
|
||||
"_unloaded_executable",
|
||||
"_jaxpr_debug_info", "_unloaded_executable",
|
||||
]
|
||||
|
||||
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
|
||||
out_shardings, auto_spmd_lowering, kept_var_idx,
|
||||
unloaded_executable=None):
|
||||
jaxpr_debug_info=None, unloaded_executable=None):
|
||||
self.xla_executable = xla_executable
|
||||
self.build_unsafe_call = build_unsafe_call
|
||||
# in_avals is a list of global and local avals. Aval is global if input
|
||||
@ -2747,6 +2752,7 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
self._out_shardings = out_shardings
|
||||
self._auto_spmd_lowering = auto_spmd_lowering
|
||||
self._kept_var_idx = kept_var_idx
|
||||
self._jaxpr_debug_info = jaxpr_debug_info
|
||||
self._unloaded_executable = unloaded_executable
|
||||
|
||||
@property
|
||||
@ -2790,7 +2796,8 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
ref_avals = self.in_avals
|
||||
check_arg_avals_for_call(ref_avals, arg_avals)
|
||||
# Check the GDA sharding and the input sharding.
|
||||
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings)
|
||||
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
|
||||
self._jaxpr_debug_info)
|
||||
return self.unsafe_call(*args) # pylint: disable=not-callable
|
||||
|
||||
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
|
||||
@ -2968,10 +2975,14 @@ def check_device_backend_on_shardings(shardings) -> bool:
|
||||
|
||||
|
||||
def check_gda_or_array_xla_sharding_match(
|
||||
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> None:
|
||||
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
||||
jaxpr_debug_info: Optional[core.JaxprDebugInfo]) -> None:
|
||||
from jax._src.array import ArrayImpl
|
||||
|
||||
for arg, xs in safe_zip(args, in_xla_shardings):
|
||||
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
|
||||
jaxpr_debug_info.arg_names)
|
||||
errors = []
|
||||
num_errors = 5
|
||||
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
|
||||
if not isinstance(arg, ArrayImpl):
|
||||
continue
|
||||
|
||||
@ -2982,10 +2993,17 @@ def check_gda_or_array_xla_sharding_match(
|
||||
not op_shardings.are_op_shardings_equal(
|
||||
arg.sharding._to_xla_op_sharding(arg.ndim),
|
||||
xs._to_xla_op_sharding(arg.ndim))):
|
||||
raise ValueError(
|
||||
f"Array sharding does not match the input sharding. "
|
||||
f"Got Array sharding: {arg.sharding} and xla sharding: {xs} for "
|
||||
f"arg shape: {arg.shape}, arg value: {arg}")
|
||||
errors.append(
|
||||
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
|
||||
f"arg {name} with shape: {arg.aval.str_short()}")
|
||||
|
||||
if errors:
|
||||
str_errors = '\n'.join(errors)
|
||||
num_mismatch_str = (f'{len(errors)} mismatches' if len(errors) < num_errors else
|
||||
f"{num_errors} mismatches out of {len(errors)}")
|
||||
raise ValueError(
|
||||
"Array(s) sharding does not match the input(s) sharding. "
|
||||
f"Here are the {num_mismatch_str}:\n{str_errors}")
|
||||
|
||||
|
||||
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
||||
|
@ -1076,7 +1076,8 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
_most_recent_pjit_call_executable.value = compiled
|
||||
# This check is expensive so only do it if enable_checks is on.
|
||||
if compiled._auto_spmd_lowering and config.jax_enable_checks:
|
||||
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings)
|
||||
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
|
||||
jaxpr.jaxpr.debug_info)
|
||||
if config.jax_distributed_debug:
|
||||
# Defensively only perform fingerprint logic if debug logging is enabled
|
||||
# NOTE(skyewm): I didn't benchmark this
|
||||
|
@ -329,11 +329,13 @@ def _check_lowered(lowered) -> None:
|
||||
# Check that we do not see new compile_args. When we add a compile_args it is
|
||||
# safe to add it to the allowed_compile_args if it does not change the semantics
|
||||
# or the calling convention of the lowered module.
|
||||
allowed_compile_args = ["backend", "mesh", "global_in_avals",
|
||||
allowed_compile_args = [
|
||||
"backend", "mesh", "global_in_avals",
|
||||
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
|
||||
"spmd_lowering", "auto_spmd_lowering",
|
||||
"tuple_args", "ordered_effects", "unordered_effects",
|
||||
"keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment"]
|
||||
"keepalive", "host_callbacks", "pmap_nreps", "committed",
|
||||
"device_assignment", "jaxpr_debug_info"]
|
||||
for compile_arg in lowered.compile_args.keys():
|
||||
if compile_arg not in allowed_compile_args:
|
||||
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
|
||||
|
@ -1229,7 +1229,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
input_data)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Array sharding does not match the input sharding."):
|
||||
r"Array\(s\) sharding does not match the input\(s\) "
|
||||
r"sharding.*\n.*for arg x"):
|
||||
compiled(arr)
|
||||
|
||||
def test_gda_auto_shardings_len(self):
|
||||
@ -1530,16 +1531,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with global_mesh:
|
||||
f = pjit(
|
||||
lambda x, y: x @ y.T,
|
||||
lambda x, y, z, a, b, c: (x @ y.T, y, z, a, b, c),
|
||||
in_shardings=NamedSharding(global_mesh, P('x' ,'y')))
|
||||
compiled = f.lower(aval, aval).compile()
|
||||
out = compiled(a1, a1)
|
||||
compiled = f.lower(aval, aval, aval, aval, aval, aval).compile()
|
||||
out, *_ = compiled(a1, a1, a1, a1, a1, a1)
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
self.assertArraysEqual(out._value, input_data @ input_data.T)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Array sharding does not match the input sharding'):
|
||||
compiled(a2, a2)
|
||||
ValueError,
|
||||
r"Array\(s\) sharding does not match the input\(s\) sharding. "
|
||||
"Here are the 5 mismatches out of 6"):
|
||||
compiled(a2, a2, a2, a2, a2, a2)
|
||||
|
||||
with global_mesh:
|
||||
f = pjit(lambda a: a, in_shardings=NamedSharding(global_mesh, P('x' ,'y')))
|
||||
abstract_inp = {'x': aval, 'y': {'y1': aval}}
|
||||
inp1 = {'x': a1, 'y': {'y1': a1}}
|
||||
compiled = f.lower(abstract_inp).compile()
|
||||
compiled(inp1)
|
||||
inp2 = {'x': a2, 'y': {'y1': a2}}
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Array\(s\) sharding does not match the input\(s\) sharding. "
|
||||
"Here are the 2 mismatches"):
|
||||
compiled(inp2)
|
||||
|
||||
def test_globally_sharded_key_array_result_8x4_single_device(self):
|
||||
input_shape = (8, 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user