If sharding is not None (that's passed to convert_element_type), only compare it with operand's sharding if the sharding is concrete. Otherwise doing getattr(operand, 'sharding') on a Tracer leads to weird timeouts.

PiperOrigin-RevId: 723595960
This commit is contained in:
Yash Katariya 2025-02-05 11:49:11 -08:00 committed by jax authors
parent f43d2b68d9
commit 0fb278a0b9

View File

@ -841,7 +841,8 @@ def _convert_element_type(
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
isinstance(operand, Array) and
not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and
(sharding is None or getattr(operand, 'sharding', None) == sharding)):
(sharding is None or
(sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))):
return operand
else:
return convert_element_type_p.bind(