Don't call get_cur_mesh_sharding if sharding-in-types mode is not enabled

PiperOrigin-RevId: 724461150
This commit is contained in:
Yash Katariya 2025-02-07 13:55:01 -08:00 committed by jax authors
parent f21b0f03b4
commit 21e1be3320

View File

@ -1773,8 +1773,10 @@ def canonicalize_value(val):
def get_cur_mesh_sharding(spec=None):
from jax._src.sharding_impls import NamedSharding # type: ignore
if not config.sharding_in_types.value:
return None
from jax._src.sharding_impls import NamedSharding # type: ignore
spec = P() if spec is None else spec
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)