From 9bb6f1852800591b2e8e0ecf08fce2b1f01a4237 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 Feb 2024 18:26:41 -0800 Subject: [PATCH] Summarize equations in sharding mismatch errors lazily. Don't spend time building error output unless there is an error. PiperOrigin-RevId: 606812191 --- jax/_src/dispatch.py | 14 +++++--------- jax/_src/interpreters/pxla.py | 5 ++++- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index af3be1379..3636176cd 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -213,7 +213,7 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr): class SourceInfo(NamedTuple): - source_info: str + source_info: source_info_util.SourceInfo eqn_name: str @@ -225,17 +225,14 @@ def jaxpr_shardings( for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield (eqn.params['sharding'], source_info) elif eqn.primitive is pjit.pjit_p: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield from ((i, source_info) for i in eqn.params['in_shardings']) yield from ((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) def _names_to_pspec(names): ndmin = max(names) + 1 if names else 0 return PartitionSpec(*(names.get(i) for i in range(ndmin))) @@ -244,8 +241,7 @@ def jaxpr_shardings( elif eqn.primitive is device_put_p: s = eqn.params['device'] if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None: - source_info = SourceInfo(source_info_util.summarize(eqn.source_info), - eqn.primitive.name) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield (s, source_info) for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_shardings(subjaxpr) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7ab1ca605..06a166fb1 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1649,7 +1649,10 @@ class DeviceAssignmentMismatch: @property def source_info_str(self): - return "" if self.source_info is None else f" at {self.source_info.source_info}" + return ( + "" if self.source_info is None + else f" at {source_info_util.summarize(self.source_info.source_info)}" + ) @property def _dev_ids_plat_str(self):