mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Simply abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
This commit is contained in:
parent
1a3c9c44dc
commit
a735bf83e5
@ -1079,8 +1079,9 @@ else:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
|
||||
return self.get_local()
|
||||
|
||||
def get_local(self):
|
||||
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
|
||||
@ -1088,6 +1089,11 @@ else:
|
||||
def set_local(self, value):
|
||||
update_thread_local_jit_state(**{self._name: value})
|
||||
|
||||
def swap_local(self, new_value):
|
||||
prev_value = self.value
|
||||
self.set_local(new_value)
|
||||
return prev_value
|
||||
|
||||
trace_state = JitConfig('trace_state')
|
||||
axis_env_state = JitConfig('axis_env_state')
|
||||
mesh_context_manager = JitConfig('mesh_context_manager')
|
||||
|
@ -1605,10 +1605,10 @@ def get_sharding(sharding, ndim):
|
||||
assert len(sharding.spec) == ndim
|
||||
return sharding
|
||||
|
||||
context_mesh = mesh_lib.abstract_mesh_context.mesh
|
||||
context_mesh = mesh_lib.get_abstract_mesh()
|
||||
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
|
||||
# code.
|
||||
if context_mesh is None:
|
||||
if not context_mesh:
|
||||
return None
|
||||
assert sharding is None
|
||||
return NamedSharding(context_mesh, P(*[None] * ndim))
|
||||
|
@ -454,18 +454,6 @@ class AbstractMesh:
|
||||
def local_mesh(self):
|
||||
_raise_value_error("local_mesh")
|
||||
|
||||
def __enter__(self):
|
||||
abstract_mesh_context.stack.append(self)
|
||||
abstract_mesh_context.mesh = self
|
||||
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
abstract_mesh_context.stack.pop()
|
||||
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
|
||||
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
|
||||
jax_config.abstract_mesh_context_manager.set_local(mesh)
|
||||
@ -478,37 +466,32 @@ def _raise_value_error(name):
|
||||
raise ValueError(f"AbstractMesh does not implement {name}")
|
||||
|
||||
|
||||
class AbstractMeshContext(threading.local):
|
||||
def __init__(self):
|
||||
self.stack = [None]
|
||||
self.mesh = self.stack[-1]
|
||||
@contextlib.contextmanager
|
||||
def set_abstract_mesh(mesh: AbstractMesh):
|
||||
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
jax_config.abstract_mesh_context_manager.set_local(prev_val)
|
||||
|
||||
abstract_mesh_context = AbstractMeshContext()
|
||||
def get_abstract_mesh():
|
||||
return jax_config.abstract_mesh_context_manager.value
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_concrete_mesh(mesh: Mesh):
|
||||
prev_val = jax_config.device_context.swap_local(mesh)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
jax_config.device_context.set_local(prev_val)
|
||||
|
||||
def get_concrete_mesh():
|
||||
return jax_config.device_context.value
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_mesh(mesh: Mesh):
|
||||
with (mesh.abstract_mesh, jax_config.sharding_in_types(True),
|
||||
enter_device_context(mesh)):
|
||||
with (set_abstract_mesh(mesh.abstract_mesh),
|
||||
jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
|
||||
yield
|
||||
|
||||
|
||||
class DeviceContext(threading.local):
|
||||
def __init__(self):
|
||||
self.stack = [None]
|
||||
self.concrete_mesh = self.stack[-1]
|
||||
|
||||
device_context = DeviceContext()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def enter_device_context(mesh: Mesh):
|
||||
device_context.stack.append(mesh)
|
||||
device_context.concrete_mesh = mesh
|
||||
jax_config.device_context.set_local(device_context.concrete_mesh)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
device_context.stack.pop()
|
||||
device_context.concrete_mesh = device_context.stack[-1]
|
||||
jax_config.device_context.set_local(device_context.concrete_mesh)
|
||||
|
@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence, Iterable
|
||||
import contextlib
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import inspect
|
||||
@ -187,7 +186,7 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
|
||||
try:
|
||||
# TODO(yashkatariya): Maybe thread this into pjit params like resource_env
|
||||
# and set the context manager down the stack?
|
||||
with p.abstract_mesh:
|
||||
with mesh_lib.set_abstract_mesh(p.abstract_mesh):
|
||||
if (core.trace_state_clean() and
|
||||
not config.debug_key_reuse.value and
|
||||
not config.data_dependent_tracing_fallback.value):
|
||||
@ -645,9 +644,9 @@ def _infer_params_impl(
|
||||
attr_token = _attr_token(flat_fun, in_type)
|
||||
|
||||
abstract_mesh = (
|
||||
get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None
|
||||
else mesh_lib.abstract_mesh_context.mesh)
|
||||
with abstract_mesh:
|
||||
get_abstract_mesh_from_avals(in_type)
|
||||
if not mesh_lib.get_abstract_mesh() else mesh_lib.get_abstract_mesh())
|
||||
with mesh_lib.set_abstract_mesh(abstract_mesh):
|
||||
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
|
||||
flat_fun, in_type, attr_token, dbg,
|
||||
HashableFunction(res_paths, closure=()),
|
||||
@ -694,9 +693,9 @@ def _infer_params_impl(
|
||||
attrs_tracked, abstract_mesh), args_flat
|
||||
|
||||
|
||||
def get_abstract_mesh(in_avals):
|
||||
def get_abstract_mesh_from_avals(in_avals):
|
||||
if not config.sharding_in_types.value:
|
||||
return contextlib.nullcontext()
|
||||
return None
|
||||
m = None
|
||||
for a in in_avals:
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
@ -1789,7 +1788,8 @@ def _pjit_lower(
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
if config.sharding_in_types.value:
|
||||
mesh = mesh_lib.device_context.concrete_mesh
|
||||
cur_mesh = mesh_lib.get_concrete_mesh()
|
||||
mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
|
||||
api_name = 'jit'
|
||||
else:
|
||||
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
|
||||
|
@ -30,7 +30,6 @@ executable protocols described above.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
@ -44,6 +43,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.sharding_impls import UnspecifiedValue, AUTO
|
||||
from jax._src.layout import Layout
|
||||
from jax._src.interpreters import mlir
|
||||
@ -717,7 +717,7 @@ class Traced(Stage):
|
||||
"_args_flat", "_arg_names", "_num_consts"]
|
||||
|
||||
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
|
||||
lower_callable, abstract_mesh=contextlib.nullcontext(),
|
||||
lower_callable, abstract_mesh=None,
|
||||
args_flat=None, arg_names=None, num_consts: int = 0):
|
||||
self.jaxpr = jaxpr
|
||||
self.args_info = args_info
|
||||
@ -747,7 +747,7 @@ class Traced(Stage):
|
||||
try:
|
||||
# TODO(yashkatariya): Maybe thread this into pjit params like resource_env
|
||||
# and set the context manager down the stack?
|
||||
with self._abstract_mesh:
|
||||
with mesh_lib.set_abstract_mesh(self._abstract_mesh):
|
||||
lowering = new_callable()
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
|
@ -46,7 +46,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.core import Tracer
|
||||
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes
|
||||
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
|
||||
from jax._src.api import _shared_code_pmap, _prepare_pmap
|
||||
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
|
||||
windowed_reductions, convolution, fft, linalg,
|
||||
@ -484,7 +484,7 @@ def _shard_map_staging(
|
||||
in_avals = [t.aval for t in in_tracers]
|
||||
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
|
||||
with (core.extend_axis_env_nd(list(mesh.shape.items())),
|
||||
pjit.get_abstract_mesh(in_avals_)):
|
||||
set_abstract_mesh(pjit.get_abstract_mesh_from_avals(in_avals_))):
|
||||
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
if check_rep:
|
||||
|
@ -5237,7 +5237,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def g(x, y):
|
||||
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
|
||||
return x * y
|
||||
|
||||
@jax.jit
|
||||
@ -5262,7 +5262,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def g(x, y):
|
||||
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
|
||||
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
|
||||
z = x @ allgatherd_y
|
||||
return jax.lax.psum(z, axis_name='y')
|
||||
|
Loading…
x
Reference in New Issue
Block a user