Fix the broken behavior of not resetting the abstract_mesh and device_context properly during __exit__.

PiperOrigin-RevId: 702762477
This commit is contained in:
Yash Katariya 2024-12-04 09:58:45 -08:00 committed by jax authors
parent 681b9c2ebe
commit 653f65452d
3 changed files with 14 additions and 41 deletions

View File

@ -455,10 +455,15 @@ class AbstractMesh:
_raise_value_error("local_mesh")
def __enter__(self):
return push_abstract_mesh_context(self)
abstract_mesh_context.stack.append(self)
abstract_mesh_context.mesh = self
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
return self
def __exit__(self, exc_type, exc_value, traceback):
pop_abstract_mesh_context()
abstract_mesh_context.stack.pop()
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
return False
@staticmethod
@ -480,35 +485,6 @@ class AbstractMeshContext(threading.local):
abstract_mesh_context = AbstractMeshContext()
def push_abstract_mesh_context(val):
abstract_mesh_context.stack.append(val)
abstract_mesh_context.mesh = val
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
# Right now that leads to weird numerical issues.
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
if m is not None)
if non_none_meshes:
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
return val
def pop_abstract_mesh_context():
abstract_mesh_context.stack.pop()
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
if m is not None)
if non_none_meshes:
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
class null_mesh_context:
def __enter__(self):
return push_abstract_mesh_context(None)
def __exit__(self, *excinfo):
pop_abstract_mesh_context()
return False
@contextlib.contextmanager
def set_mesh(mesh: Mesh):
@ -529,14 +505,10 @@ device_context = DeviceContext()
def enter_device_context(mesh: Mesh):
device_context.stack.append(mesh)
device_context.concrete_mesh = mesh
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
if non_none_meshes:
jax_config.device_context.set_local(non_none_meshes)
jax_config.device_context.set_local(device_context.concrete_mesh)
try:
yield
finally:
device_context.stack.pop()
device_context.concrete_mesh = device_context.stack[-1]
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
if non_none_meshes:
jax_config.device_context.set_local(non_none_meshes)
jax_config.device_context.set_local(device_context.concrete_mesh)

View File

@ -16,6 +16,7 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence, Iterable
import contextlib
import dataclasses
from functools import partial
import inspect
@ -695,7 +696,7 @@ def _infer_params_impl(
def get_abstract_mesh(in_avals):
if not config.sharding_in_types.value:
return mesh_lib.null_mesh_context()
return contextlib.nullcontext()
m = None
for a in in_avals:
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
@ -708,7 +709,7 @@ def get_abstract_mesh(in_avals):
m = a.sharding.mesh # type: ignore
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if m is None:
return mesh_lib.null_mesh_context()
return contextlib.nullcontext()
assert isinstance(m, AbstractMesh)
return m

View File

@ -30,6 +30,7 @@ executable protocols described above.
"""
from __future__ import annotations
import contextlib
import functools
from collections.abc import Sequence
from dataclasses import dataclass
@ -43,7 +44,6 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src import mesh as mesh_lib
from jax._src.sharding_impls import UnspecifiedValue, AUTO
from jax._src.layout import Layout
from jax._src.interpreters import mlir
@ -717,7 +717,7 @@ class Traced(Stage):
"_args_flat", "_arg_names", "_num_consts"]
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
lower_callable, abstract_mesh=mesh_lib.null_mesh_context(),
lower_callable, abstract_mesh=contextlib.nullcontext(),
args_flat=None, arg_names=None, num_consts: int = 0):
self.jaxpr = jaxpr
self.args_info = args_info