mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Fix type failures under an upcoming pytype change.
PiperOrigin-RevId: 522591195
This commit is contained in:
parent
87a1fea1c7
commit
b15ebb1bc5
@ -70,6 +70,7 @@ class custom_vmap:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
||||
assert self.vmap_rule is not None
|
||||
out_flat = custom_vmap_p.bind(*consts, *args_flat,
|
||||
call=closed_call,
|
||||
rule=ClosedRule(self.vmap_rule),
|
||||
|
@ -1942,7 +1942,7 @@ class DeviceAssignmentMismatch:
|
||||
return f"device ids {self.device_ids} on platform {self.platform}"
|
||||
|
||||
def m_type_str(self, api_name):
|
||||
return (f'{self.source_info.eqn_name} inside {api_name}'
|
||||
return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}'
|
||||
if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type)
|
||||
|
||||
def _str(self, api_name):
|
||||
|
Loading…
x
Reference in New Issue
Block a user