mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Create a null_mesh_context
internal context manager to handle null contexts properly.
PiperOrigin-RevId: 700167406
This commit is contained in:
parent
59e13f8114
commit
627debc78b
@ -455,17 +455,10 @@ class AbstractMesh:
|
||||
_raise_value_error("local_mesh")
|
||||
|
||||
def __enter__(self):
|
||||
mesh_context.stack.append(self)
|
||||
mesh_context.mesh = self
|
||||
jax_config.abstract_mesh_context_manager.set_local(
|
||||
tuple(m for m in mesh_context.stack if m is not None))
|
||||
return self
|
||||
return push_mesh_context(self)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
mesh_context.stack.pop()
|
||||
mesh_context.mesh = mesh_context.stack[-1]
|
||||
jax_config.abstract_mesh_context_manager.set_local(
|
||||
tuple(m for m in mesh_context.stack if m is not None))
|
||||
pop_mesh_context()
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@ -486,3 +479,26 @@ class MeshContext(threading.local):
|
||||
self.mesh = self.stack[-1]
|
||||
|
||||
mesh_context = MeshContext()
|
||||
|
||||
def push_mesh_context(val):
|
||||
mesh_context.stack.append(val)
|
||||
mesh_context.mesh = val
|
||||
jax_config.abstract_mesh_context_manager.set_local(
|
||||
tuple(m for m in mesh_context.stack if m is not None))
|
||||
return val
|
||||
|
||||
def pop_mesh_context():
|
||||
mesh_context.stack.pop()
|
||||
mesh_context.mesh = mesh_context.stack[-1]
|
||||
jax_config.abstract_mesh_context_manager.set_local(
|
||||
tuple(m for m in mesh_context.stack if m is not None))
|
||||
|
||||
|
||||
class null_mesh_context:
|
||||
|
||||
def __enter__(self):
|
||||
return push_mesh_context(None)
|
||||
|
||||
def __exit__(self, *excinfo):
|
||||
pop_mesh_context()
|
||||
return False
|
||||
|
@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence, Iterable
|
||||
import contextlib
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import inspect
|
||||
@ -696,7 +695,7 @@ def _infer_params_impl(
|
||||
|
||||
def get_abstract_mesh(in_avals):
|
||||
if not config.sharding_in_types.value:
|
||||
return contextlib.nullcontext()
|
||||
return mesh_lib.null_mesh_context()
|
||||
m = None
|
||||
for a in in_avals:
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
@ -709,7 +708,7 @@ def get_abstract_mesh(in_avals):
|
||||
m = a.sharding.mesh # type: ignore
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
if m is None:
|
||||
return contextlib.nullcontext()
|
||||
return mesh_lib.null_mesh_context()
|
||||
assert m is not None
|
||||
return m
|
||||
|
||||
|
@ -30,7 +30,6 @@ executable protocols described above.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
@ -44,6 +43,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.sharding_impls import UnspecifiedValue, AUTO
|
||||
from jax._src.layout import Layout
|
||||
from jax._src.interpreters import mlir
|
||||
@ -717,7 +717,7 @@ class Traced(Stage):
|
||||
"_args_flat", "_arg_names", "_num_consts"]
|
||||
|
||||
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
|
||||
lower_callable, abstract_mesh=contextlib.nullcontext(),
|
||||
lower_callable, abstract_mesh=mesh_lib.null_mesh_context(),
|
||||
args_flat=None, arg_names=None, num_consts: int = 0):
|
||||
self.jaxpr = jaxpr
|
||||
self.args_info = args_info
|
||||
|
Loading…
x
Reference in New Issue
Block a user