diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index af3be1379..3636176cd 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -213,7 +213,7 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr): class SourceInfo(NamedTuple): - source_info: str + source_info: source_info_util.SourceInfo eqn_name: str @@ -225,17 +225,14 @@ def jaxpr_shardings( for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield (eqn.params['sharding'], source_info) elif eqn.primitive is pjit.pjit_p: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield from ((i, source_info) for i in eqn.params['in_shardings']) yield from ((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) def _names_to_pspec(names): ndmin = max(names) + 1 if names else 0 return PartitionSpec(*(names.get(i) for i in range(ndmin))) @@ -244,8 +241,7 @@ def jaxpr_shardings( elif eqn.primitive is device_put_p: s = eqn.params['device'] if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield (s, source_info) for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_shardings(subjaxpr) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7ab1ca605..06a166fb1 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1649,7 +1649,10 @@ class DeviceAssignmentMismatch: @property def source_info_str(self): - return "" if self.source_info is None else f" at {self.source_info.source_info}" + return ( + "" if self.source_info is None + else f" at {source_info_util.summarize(self.source_info.source_info)}" + ) @property def _dev_ids_plat_str(self):