Summarize equations in sharding mismatch errors lazily.

Don't spend time building error output unless there is an error.

PiperOrigin-RevId: 606812191
This commit is contained in:
Peter Hawkins 2024-02-13 18:26:41 -08:00 committed by jax authors
parent 7156f20b44
commit 9bb6f18528
2 changed files with 9 additions and 10 deletions

View File

@ -213,7 +213,7 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr):
class SourceInfo(NamedTuple):
source_info: str
source_info: source_info_util.SourceInfo
eqn_name: str
@ -225,17 +225,14 @@ def jaxpr_shardings(
for eqn in jaxpr.eqns:
if eqn.primitive is pjit.sharding_constraint_p:
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield (eqn.params['sharding'], source_info)
elif eqn.primitive is pjit.pjit_p:
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield from ((i, source_info) for i in eqn.params['in_shardings'])
yield from ((o, source_info) for o in eqn.params['out_shardings'])
elif eqn.primitive is shard_map.shard_map_p:
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
def _names_to_pspec(names):
ndmin = max(names) + 1 if names else 0
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
@ -244,8 +241,7 @@ def jaxpr_shardings(
elif eqn.primitive is device_put_p:
s = eqn.params['device']
if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield (s, source_info)
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_shardings(subjaxpr)

View File

@ -1649,7 +1649,10 @@ class DeviceAssignmentMismatch:
@property
def source_info_str(self):
return "" if self.source_info is None else f" at {self.source_info.source_info}"
return (
"" if self.source_info is None
else f" at {source_info_util.summarize(self.source_info.source_info)}"
)
@property
def _dev_ids_plat_str(self):