mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Don't call get_cur_mesh_sharding
if sharding-in-types mode is not enabled
PiperOrigin-RevId: 724461150
This commit is contained in:
parent
f21b0f03b4
commit
21e1be3320
@ -1773,8 +1773,10 @@ def canonicalize_value(val):
|
||||
|
||||
|
||||
def get_cur_mesh_sharding(spec=None):
|
||||
from jax._src.sharding_impls import NamedSharding # type: ignore
|
||||
if not config.sharding_in_types.value:
|
||||
return None
|
||||
|
||||
from jax._src.sharding_impls import NamedSharding # type: ignore
|
||||
spec = P() if spec is None else spec
|
||||
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user