From b15ebb1bc59caa22b1f9c483093dd045fa7797b7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 7 Apr 2023 07:09:44 -0700 Subject: [PATCH] Fix type failures under an upcoming pytype change. PiperOrigin-RevId: 522591195 --- jax/_src/custom_batching.py | 1 + jax/_src/interpreters/pxla.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 64fd57187..34d3ca9fa 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -70,6 +70,7 @@ class custom_vmap: jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) in_tree = treedef_tuple((tree_structure(consts), in_tree)) + assert self.vmap_rule is not None out_flat = custom_vmap_p.bind(*consts, *args_flat, call=closed_call, rule=ClosedRule(self.vmap_rule), diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f56569b23..491926428 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1942,7 +1942,7 @@ class DeviceAssignmentMismatch: return f"device ids {self.device_ids} on platform {self.platform}" def m_type_str(self, api_name): - return (f'{self.source_info.eqn_name} inside {api_name}' + return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}' if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) def _str(self, api_name):