mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
If sharding is not None (that's passed to convert_element_type), only compare it with operand's sharding if the sharding is concrete. Otherwise doing getattr(operand, 'sharding')
on a Tracer
leads to weird timeouts.
PiperOrigin-RevId: 723595960
This commit is contained in:
parent
f43d2b68d9
commit
0fb278a0b9
@ -841,7 +841,8 @@ def _convert_element_type(
|
||||
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
|
||||
isinstance(operand, Array) and
|
||||
not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and
|
||||
(sharding is None or getattr(operand, 'sharding', None) == sharding)):
|
||||
(sharding is None or
|
||||
(sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))):
|
||||
return operand
|
||||
else:
|
||||
return convert_element_type_p.bind(
|
||||
|
Loading…
x
Reference in New Issue
Block a user