mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Check the type of mesh in use_abstract_mesh
and use_concrete_mesh
PiperOrigin-RevId: 738190879
This commit is contained in:
parent
3f91b4b43a
commit
663ef7ae01
@ -33,7 +33,6 @@ from jax._src import errors
|
||||
from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.mesh import use_concrete_mesh
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
@ -43,7 +42,8 @@ from jax._src.lib import xla_extension as xe
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding,
|
||||
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
|
||||
device_replica_id_map, hashed_index, num_addressable_indices,
|
||||
local_to_global_shape, use_concrete_mesh) # pyformat: disable
|
||||
from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike
|
||||
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
|
||||
import numpy as np
|
||||
|
@ -543,6 +543,10 @@ class UseAbstractMeshContextManager:
|
||||
__slots__ = ['mesh', 'prev']
|
||||
|
||||
def __init__(self, mesh: AbstractMesh):
|
||||
if not isinstance(mesh, AbstractMesh):
|
||||
raise ValueError(
|
||||
"Expected mesh of type `jax.sharding.AbstractMesh`. Got type:"
|
||||
f" {type(mesh)}")
|
||||
self.mesh = mesh
|
||||
|
||||
def __enter__(self):
|
||||
@ -557,13 +561,5 @@ def get_abstract_mesh():
|
||||
val = jax_config.abstract_mesh_context_manager.value
|
||||
return empty_abstract_mesh if val is None else val
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_concrete_mesh(mesh: Mesh | None):
|
||||
prev_val = jax_config.device_context.swap_local(mesh)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
jax_config.device_context.set_local(prev_val)
|
||||
|
||||
def get_concrete_mesh() -> Mesh | None:
|
||||
return jax_config.device_context.value
|
||||
|
@ -1387,8 +1387,7 @@ def use_mesh(mesh: mesh_lib.Mesh):
|
||||
# if not core.trace_state_clean():
|
||||
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
|
||||
|
||||
with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh),
|
||||
mesh_lib.use_concrete_mesh(mesh)):
|
||||
with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh):
|
||||
yield
|
||||
|
||||
def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
|
||||
@ -1408,3 +1407,18 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
|
||||
prev_mesh = config.device_context.get_global()
|
||||
config.device_context.set_global(mesh)
|
||||
return prev_mesh
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_concrete_mesh(mesh: mesh_lib.Mesh | None):
|
||||
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
|
||||
raise ValueError(
|
||||
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
|
||||
# TODO(yashkatariya): Enable this.
|
||||
# if not core.trace_state_clean():
|
||||
# raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')
|
||||
|
||||
prev_val = config.device_context.swap_local(mesh)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
config.device_context.set_local(prev_val)
|
||||
|
Loading…
x
Reference in New Issue
Block a user