Simplify the abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py

PiperOrigin-RevId: 702852769
This commit is contained in:
Yash Katariya 2024-12-04 14:03:45 -08:00 committed by jax authors
parent 1a3c9c44dc
commit a735bf83e5
7 changed files with 47 additions and 58 deletions

View File

@ -1079,8 +1079,9 @@ else:
def __init__(self, name):
self._name = name
@property
def value(self):
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
return self.get_local()
def get_local(self):
# Read the attribute named by self._name from the extra_jit_context of
# jax_jit.thread_local_state() and return it.
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
@ -1088,6 +1089,11 @@ else:
def set_local(self, value):
# Forward `value` to update_thread_local_jit_state as a keyword argument
# whose key is this config's name (self._name).
update_thread_local_jit_state(**{self._name: value})
def swap_local(self, new_value):
# Install `new_value` via set_local and return the value that self.value
# held beforehand, so the caller can restore it later with set_local.
prev_value = self.value
self.set_local(new_value)
return prev_value
trace_state = JitConfig('trace_state')
axis_env_state = JitConfig('axis_env_state')
mesh_context_manager = JitConfig('mesh_context_manager')

View File

@ -1605,10 +1605,10 @@ def get_sharding(sharding, ndim):
assert len(sharding.spec) == ndim
return sharding
context_mesh = mesh_lib.abstract_mesh_context.mesh
context_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
# code.
if context_mesh is None:
if not context_mesh:
return None
assert sharding is None
return NamedSharding(context_mesh, P(*[None] * ndim))

View File

@ -454,18 +454,6 @@ class AbstractMesh:
def local_mesh(self):
_raise_value_error("local_mesh")
def __enter__(self):
abstract_mesh_context.stack.append(self)
abstract_mesh_context.mesh = self
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
return self
def __exit__(self, exc_type, exc_value, traceback):
abstract_mesh_context.stack.pop()
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
return False
@staticmethod
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
jax_config.abstract_mesh_context_manager.set_local(mesh)
@ -478,37 +466,32 @@ def _raise_value_error(name):
raise ValueError(f"AbstractMesh does not implement {name}")
class AbstractMeshContext(threading.local):
def __init__(self):
self.stack = [None]
self.mesh = self.stack[-1]
@contextlib.contextmanager
def set_abstract_mesh(mesh: AbstractMesh):
# Context manager that swaps `mesh` into
# jax_config.abstract_mesh_context_manager for the duration of the `with`
# body. The value returned by swap_local is written back in `finally`, so
# it is restored even when the body raises.
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
try:
yield
finally:
jax_config.abstract_mesh_context_manager.set_local(prev_val)
abstract_mesh_context = AbstractMeshContext()
def get_abstract_mesh():
# Return the current value of jax_config.abstract_mesh_context_manager,
# i.e. whatever set_abstract_mesh (or another set_local call) last stored.
return jax_config.abstract_mesh_context_manager.value
@contextlib.contextmanager
def set_concrete_mesh(mesh: Mesh):
# Context manager that swaps `mesh` into jax_config.device_context for the
# duration of the `with` body. The value returned by swap_local is written
# back in `finally`, so it is restored even when the body raises.
prev_val = jax_config.device_context.swap_local(mesh)
try:
yield
finally:
jax_config.device_context.set_local(prev_val)
def get_concrete_mesh():
# Return the current value of jax_config.device_context, i.e. whatever
# set_concrete_mesh (or another set_local call) last stored.
return jax_config.device_context.value
@contextlib.contextmanager
def set_mesh(mesh: Mesh):
with (mesh.abstract_mesh, jax_config.sharding_in_types(True),
enter_device_context(mesh)):
with (set_abstract_mesh(mesh.abstract_mesh),
jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
yield
class DeviceContext(threading.local):
def __init__(self):
self.stack = [None]
self.concrete_mesh = self.stack[-1]
device_context = DeviceContext()
@contextlib.contextmanager
def enter_device_context(mesh: Mesh):
device_context.stack.append(mesh)
device_context.concrete_mesh = mesh
jax_config.device_context.set_local(device_context.concrete_mesh)
try:
yield
finally:
device_context.stack.pop()
device_context.concrete_mesh = device_context.stack[-1]
jax_config.device_context.set_local(device_context.concrete_mesh)

View File

@ -16,7 +16,6 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence, Iterable
import contextlib
import dataclasses
from functools import partial
import inspect
@ -187,7 +186,7 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
try:
# TODO(yashkatariya): Maybe thread this into pjit params like resource_env
# and set the context manager down the stack?
with p.abstract_mesh:
with mesh_lib.set_abstract_mesh(p.abstract_mesh):
if (core.trace_state_clean() and
not config.debug_key_reuse.value and
not config.data_dependent_tracing_fallback.value):
@ -645,9 +644,9 @@ def _infer_params_impl(
attr_token = _attr_token(flat_fun, in_type)
abstract_mesh = (
get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None
else mesh_lib.abstract_mesh_context.mesh)
with abstract_mesh:
get_abstract_mesh_from_avals(in_type)
if not mesh_lib.get_abstract_mesh() else mesh_lib.get_abstract_mesh())
with mesh_lib.set_abstract_mesh(abstract_mesh):
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
@ -694,9 +693,9 @@ def _infer_params_impl(
attrs_tracked, abstract_mesh), args_flat
def get_abstract_mesh(in_avals):
def get_abstract_mesh_from_avals(in_avals):
if not config.sharding_in_types.value:
return contextlib.nullcontext()
return None
m = None
for a in in_avals:
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
@ -1789,7 +1788,8 @@ def _pjit_lower(
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
if config.sharding_in_types.value:
mesh = mesh_lib.device_context.concrete_mesh
cur_mesh = mesh_lib.get_concrete_mesh()
mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
api_name = 'jit'
else:
mesh, api_name = ((resource_env.physical_mesh, 'pjit')

View File

@ -30,7 +30,6 @@ executable protocols described above.
"""
from __future__ import annotations
import contextlib
import functools
from collections.abc import Sequence
from dataclasses import dataclass
@ -44,6 +43,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src import mesh as mesh_lib
from jax._src.sharding_impls import UnspecifiedValue, AUTO
from jax._src.layout import Layout
from jax._src.interpreters import mlir
@ -717,7 +717,7 @@ class Traced(Stage):
"_args_flat", "_arg_names", "_num_consts"]
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
lower_callable, abstract_mesh=contextlib.nullcontext(),
lower_callable, abstract_mesh=None,
args_flat=None, arg_names=None, num_consts: int = 0):
self.jaxpr = jaxpr
self.args_info = args_info
@ -747,7 +747,7 @@ class Traced(Stage):
try:
# TODO(yashkatariya): Maybe thread this into pjit params like resource_env
# and set the context manager down the stack?
with self._abstract_mesh:
with mesh_lib.set_abstract_mesh(self._abstract_mesh):
lowering = new_callable()
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args

View File

@ -46,7 +46,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
@ -484,7 +484,7 @@ def _shard_map_staging(
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
with (core.extend_axis_env_nd(list(mesh.shape.items())),
pjit.get_abstract_mesh(in_avals_)):
set_abstract_mesh(pjit.get_abstract_mesh_from_avals(in_avals_))):
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
_check_names(out_names_thunk(), out_avals_)
if check_rep:

View File

@ -5237,7 +5237,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def g(x, y):
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
return x * y
@jax.jit
@ -5262,7 +5262,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def g(x, y):
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
z = x @ allgatherd_y
return jax.lax.psum(z, axis_name='y')