Create a null_mesh_context internal context manager to handle null contexts properly.

PiperOrigin-RevId: 700167406
This commit is contained in:
Yash Katariya 2024-11-25 18:29:52 -08:00 committed by jax authors
parent 59e13f8114
commit 627debc78b
3 changed files with 29 additions and 14 deletions

View File

@ -455,17 +455,10 @@ class AbstractMesh:
_raise_value_error("local_mesh")
def __enter__(self):
mesh_context.stack.append(self)
mesh_context.mesh = self
jax_config.abstract_mesh_context_manager.set_local(
tuple(m for m in mesh_context.stack if m is not None))
return self
return push_mesh_context(self)
def __exit__(self, exc_type, exc_value, traceback):
mesh_context.stack.pop()
mesh_context.mesh = mesh_context.stack[-1]
jax_config.abstract_mesh_context_manager.set_local(
tuple(m for m in mesh_context.stack if m is not None))
pop_mesh_context()
return False
@staticmethod
@ -486,3 +479,26 @@ class MeshContext(threading.local):
self.mesh = self.stack[-1]
mesh_context = MeshContext()
def push_mesh_context(val):
mesh_context.stack.append(val)
mesh_context.mesh = val
jax_config.abstract_mesh_context_manager.set_local(
tuple(m for m in mesh_context.stack if m is not None))
return val
def pop_mesh_context():
mesh_context.stack.pop()
mesh_context.mesh = mesh_context.stack[-1]
jax_config.abstract_mesh_context_manager.set_local(
tuple(m for m in mesh_context.stack if m is not None))
class null_mesh_context:
def __enter__(self):
return push_mesh_context(None)
def __exit__(self, *excinfo):
pop_mesh_context()
return False

View File

@ -16,7 +16,6 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence, Iterable
import contextlib
import dataclasses
from functools import partial
import inspect
@ -696,7 +695,7 @@ def _infer_params_impl(
def get_abstract_mesh(in_avals):
if not config.sharding_in_types.value:
return contextlib.nullcontext()
return mesh_lib.null_mesh_context()
m = None
for a in in_avals:
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
@ -709,7 +708,7 @@ def get_abstract_mesh(in_avals):
m = a.sharding.mesh # type: ignore
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if m is None:
return contextlib.nullcontext()
return mesh_lib.null_mesh_context()
assert m is not None
return m

View File

@ -30,7 +30,6 @@ executable protocols described above.
"""
from __future__ import annotations
import contextlib
import functools
from collections.abc import Sequence
from dataclasses import dataclass
@ -44,6 +43,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src import mesh as mesh_lib
from jax._src.sharding_impls import UnspecifiedValue, AUTO
from jax._src.layout import Layout
from jax._src.interpreters import mlir
@ -717,7 +717,7 @@ class Traced(Stage):
"_args_flat", "_arg_names", "_num_consts"]
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
lower_callable, abstract_mesh=contextlib.nullcontext(),
lower_callable, abstract_mesh=mesh_lib.null_mesh_context(),
args_flat=None, arg_names=None, num_consts: int = 0):
self.jaxpr = jaxpr
self.args_info = args_info