mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Summarize equations in sharding mismatch errors lazily.
Don't spend time building error output unless there is an error. PiperOrigin-RevId: 606812191
This commit is contained in:
parent
7156f20b44
commit
9bb6f18528
@ -213,7 +213,7 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr):
|
||||
|
||||
|
||||
class SourceInfo(NamedTuple):
|
||||
source_info: str
|
||||
source_info: source_info_util.SourceInfo
|
||||
eqn_name: str
|
||||
|
||||
|
||||
@ -225,17 +225,14 @@ def jaxpr_shardings(
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive is pjit.sharding_constraint_p:
|
||||
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||
eqn.primitive.name)
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield (eqn.params['sharding'], source_info)
|
||||
elif eqn.primitive is pjit.pjit_p:
|
||||
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||
eqn.primitive.name)
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield from ((i, source_info) for i in eqn.params['in_shardings'])
|
||||
yield from ((o, source_info) for o in eqn.params['out_shardings'])
|
||||
elif eqn.primitive is shard_map.shard_map_p:
|
||||
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||
eqn.primitive.name)
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
def _names_to_pspec(names):
|
||||
ndmin = max(names) + 1 if names else 0
|
||||
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
|
||||
@ -244,8 +241,7 @@ def jaxpr_shardings(
|
||||
elif eqn.primitive is device_put_p:
|
||||
s = eqn.params['device']
|
||||
if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
|
||||
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||
eqn.primitive.name)
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield (s, source_info)
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
yield from jaxpr_shardings(subjaxpr)
|
||||
|
@ -1649,7 +1649,10 @@ class DeviceAssignmentMismatch:
|
||||
|
||||
@property
|
||||
def source_info_str(self):
|
||||
return "" if self.source_info is None else f" at {self.source_info.source_info}"
|
||||
return (
|
||||
"" if self.source_info is None
|
||||
else f" at {source_info_util.summarize(self.source_info.source_info)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def _dev_ids_plat_str(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user