Fix type failures under an upcoming pytype change.

PiperOrigin-RevId: 522591195
This commit is contained in:
jax authors 2023-04-07 07:09:44 -07:00
parent 87a1fea1c7
commit b15ebb1bc5
2 changed files with 2 additions and 1 deletions

View File

@ -70,6 +70,7 @@ class custom_vmap:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None
out_flat = custom_vmap_p.bind(*consts, *args_flat,
call=closed_call,
rule=ClosedRule(self.vmap_rule),

View File

@ -1942,7 +1942,7 @@ class DeviceAssignmentMismatch:
return f"device ids {self.device_ids} on platform {self.platform}"
def m_type_str(self, api_name):
return (f'{self.source_info.eqn_name} inside {api_name}'
return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}'
if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type)
def _str(self, api_name):