Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Fix the broken behavior of not resetting the abstract_mesh and device_context properly during __exit__.
PiperOrigin-RevId: 702762477
This commit is contained in:
parent
681b9c2ebe
commit
653f65452d
@ -455,10 +455,15 @@ class AbstractMesh:
|
||||
_raise_value_error("local_mesh")
|
||||
|
||||
def __enter__(self):
|
||||
return push_abstract_mesh_context(self)
|
||||
abstract_mesh_context.stack.append(self)
|
||||
abstract_mesh_context.mesh = self
|
||||
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pop_abstract_mesh_context()
|
||||
abstract_mesh_context.stack.pop()
|
||||
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
|
||||
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@ -480,35 +485,6 @@ class AbstractMeshContext(threading.local):
|
||||
|
||||
abstract_mesh_context = AbstractMeshContext()
|
||||
|
||||
def push_abstract_mesh_context(val):
|
||||
abstract_mesh_context.stack.append(val)
|
||||
abstract_mesh_context.mesh = val
|
||||
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
|
||||
# Right now that leads to weird numerical issues.
|
||||
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
|
||||
if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
|
||||
return val
|
||||
|
||||
def pop_abstract_mesh_context():
|
||||
abstract_mesh_context.stack.pop()
|
||||
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
|
||||
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
|
||||
if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
|
||||
|
||||
|
||||
class null_mesh_context:
|
||||
|
||||
def __enter__(self):
|
||||
return push_abstract_mesh_context(None)
|
||||
|
||||
def __exit__(self, *excinfo):
|
||||
pop_abstract_mesh_context()
|
||||
return False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_mesh(mesh: Mesh):
|
||||
@ -529,14 +505,10 @@ device_context = DeviceContext()
|
||||
def enter_device_context(mesh: Mesh):
|
||||
device_context.stack.append(mesh)
|
||||
device_context.concrete_mesh = mesh
|
||||
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.device_context.set_local(non_none_meshes)
|
||||
jax_config.device_context.set_local(device_context.concrete_mesh)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
device_context.stack.pop()
|
||||
device_context.concrete_mesh = device_context.stack[-1]
|
||||
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.device_context.set_local(non_none_meshes)
|
||||
jax_config.device_context.set_local(device_context.concrete_mesh)
|
||||
|
@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence, Iterable
|
||||
import contextlib
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import inspect
|
||||
@ -695,7 +696,7 @@ def _infer_params_impl(
|
||||
|
||||
def get_abstract_mesh(in_avals):
|
||||
if not config.sharding_in_types.value:
|
||||
return mesh_lib.null_mesh_context()
|
||||
return contextlib.nullcontext()
|
||||
m = None
|
||||
for a in in_avals:
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
@ -708,7 +709,7 @@ def get_abstract_mesh(in_avals):
|
||||
m = a.sharding.mesh # type: ignore
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
if m is None:
|
||||
return mesh_lib.null_mesh_context()
|
||||
return contextlib.nullcontext()
|
||||
assert isinstance(m, AbstractMesh)
|
||||
return m
|
||||
|
||||
|
@ -30,6 +30,7 @@ executable protocols described above.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
@ -43,7 +44,6 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.sharding_impls import UnspecifiedValue, AUTO
|
||||
from jax._src.layout import Layout
|
||||
from jax._src.interpreters import mlir
|
||||
@ -717,7 +717,7 @@ class Traced(Stage):
|
||||
"_args_flat", "_arg_names", "_num_consts"]
|
||||
|
||||
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
|
||||
lower_callable, abstract_mesh=mesh_lib.null_mesh_context(),
|
||||
lower_callable, abstract_mesh=contextlib.nullcontext(),
|
||||
args_flat=None, arg_names=None, num_consts: int = 0):
|
||||
self.jaxpr = jaxpr
|
||||
self.args_info = args_info
|
||||
|
Loading…
x
Reference in New Issue
Block a user