Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it i.e. this is just __enter__ of a ctx manager without __exit__

PiperOrigin-RevId: 736261724
This commit is contained in:
Yash Katariya 2025-03-12 14:12:06 -07:00 committed by jax authors
parent 8674495fd7
commit 47480b4493
6 changed files with 47 additions and 13 deletions

View File

@ -587,9 +587,3 @@ def set_concrete_mesh(mesh: Mesh | None):
def get_concrete_mesh():
return jax_config.device_context.value
@contextlib.contextmanager
def use_mesh(mesh: Mesh):
with set_abstract_mesh(mesh.abstract_mesh), set_concrete_mesh(mesh):
yield

View File

@ -690,7 +690,7 @@ def _infer_params(
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[Any]]:
if ji.use_resource_env:
with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
return _infer_params_internal(fun, ji, args, kwargs)
return _infer_params_internal(fun, ji, args, kwargs)

View File

@ -15,6 +15,7 @@
from __future__ import annotations
import collections
import contextlib
from collections.abc import Mapping, Sequence
import dataclasses
import functools
@ -1410,3 +1411,28 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
allow_split_physical_axes=allow_split_physical_axes)
axis_types = _get_axis_types(auto_axes, explicit_axes, manual_axes)
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
@contextlib.contextmanager
def use_mesh(mesh: mesh_lib.Mesh):
if not isinstance(mesh, mesh_lib.Mesh):
raise ValueError(
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
# TODO(yashkatariya): Enable this.
# if not core.trace_state_clean():
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
with (mesh_lib.set_abstract_mesh(mesh.abstract_mesh),
mesh_lib.set_concrete_mesh(mesh)):
yield
def set_mesh(mesh: mesh_lib.Mesh) -> None:
if not isinstance(mesh, mesh_lib.Mesh):
raise ValueError(
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
if not core.trace_state_clean():
raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')
config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh)
config.device_context.set_local(mesh)

View File

@ -1576,7 +1576,7 @@ def with_user_mesh(sizes, names, axis_types=None):
def decorator(fn):
def mesh_fn(*args, **kwargs):
mesh = create_mesh(sizes, names, axis_types=axis_types)
with mesh_lib.use_mesh(mesh):
with jax.sharding.use_mesh(mesh):
return fn(*args, **kwargs, mesh=mesh)
return mesh_fn
return decorator

View File

@ -22,6 +22,8 @@ from jax._src.sharding_impls import (
PmapSharding as PmapSharding,
GSPMDSharding as GSPMDSharding,
PositionalSharding as PositionalSharding,
use_mesh as use_mesh,
set_mesh as set_mesh,
)
from jax._src.partition_spec import (
PartitionSpec as PartitionSpec,
@ -30,7 +32,6 @@ from jax._src.interpreters.pxla import Mesh as Mesh
from jax._src.mesh import (
AbstractMesh as AbstractMesh,
AxisTypes as AxisTypes,
use_mesh as use_mesh
)
_deprecations = {

View File

@ -5943,7 +5943,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
a = z @ x2
return a
with mesh_lib.use_mesh(mesh):
with jax.sharding.use_mesh(mesh):
out = f(arr, arr.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
lowered_text = f.lower(arr, arr.T).as_text()
@ -5952,7 +5952,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Explicit: 'x',
mesh_lib.AxisTypes.Auto: 'y'})
with mesh_lib.use_mesh(mesh2):
with jax.sharding.use_mesh(mesh2):
arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
out = f(arr, arr2)
@ -5966,7 +5966,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Explicit: 'y',
mesh_lib.AxisTypes.Auto: 'x'})
with mesh_lib.use_mesh(mesh3):
with jax.sharding.use_mesh(mesh3):
arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P(None, 'x')))
out = f(arr, arr2)
@ -6143,7 +6143,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, s)
auto_mesh = jax.make_mesh((2,), 'x', auto_axes='x')
with mesh_lib.use_mesh(auto_mesh):
with jax.sharding.use_mesh(auto_mesh):
arr2 = jnp.ones(8)
self.assertDictEqual(arr2.sharding.mesh.axis_types,
{AxisTypes.Auto: ('x',)})
@ -7083,6 +7083,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = f(np.arange(8))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
def test_set_mesh(self):
mesh = jtu.create_mesh((2,), ('x',), axis_types={AxisTypes.Explicit: 'x'})
prev_mesh = config.device_context.value
prev_abstract_mesh = config.abstract_mesh_context_manager.value
try:
jax.sharding.set_mesh(mesh)
out = reshard(np.arange(8), P('x'))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
finally:
config.device_context.set_local(prev_mesh)
config.abstract_mesh_context_manager.set_local(prev_abstract_mesh)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):