mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a `set_mesh` API to `jax.sharding`.
`set_mesh` sets the sharding and never unsets it, i.e. it is just the
`__enter__` of a context manager without the matching `__exit__`.
PiperOrigin-RevId: 736261724
This commit is contained in:
parent
8674495fd7
commit
47480b4493
@ -587,9 +587,3 @@ def set_concrete_mesh(mesh: Mesh | None):
|
||||
|
||||
def get_concrete_mesh():
  """Return the value currently held by `jax_config.device_context`."""
  current_mesh = jax_config.device_context.value
  return current_mesh
|
||||
|
||||
|
||||
@contextlib.contextmanager
def use_mesh(mesh: Mesh):
  """Context manager that activates `mesh` for the duration of the block.

  Enters `set_abstract_mesh(mesh.abstract_mesh)` first and then
  `set_concrete_mesh(mesh)`; both are exited (in reverse order) when the
  `with` body finishes.
  """
  with set_abstract_mesh(mesh.abstract_mesh):
    with set_concrete_mesh(mesh):
      yield
|
||||
|
@ -690,7 +690,7 @@ def _infer_params(
|
||||
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
if ji.use_resource_env:
|
||||
with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
|
||||
with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
|
||||
return _infer_params_internal(fun, ji, args, kwargs)
|
||||
return _infer_params_internal(fun, ji, args, kwargs)
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
from collections.abc import Mapping, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
@ -1410,3 +1411,28 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
|
||||
allow_split_physical_axes=allow_split_physical_axes)
|
||||
axis_types = _get_axis_types(auto_axes, explicit_axes, manual_axes)
|
||||
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def use_mesh(mesh: mesh_lib.Mesh):
  """Context manager that activates `mesh` for the duration of the block.

  Args:
    mesh: the `jax.sharding.Mesh` to install as both the abstract mesh (via
      `mesh.abstract_mesh`) and the concrete mesh.

  Raises:
    ValueError: if `mesh` is not a `mesh_lib.Mesh` instance.
  """
  mesh_is_valid = isinstance(mesh, mesh_lib.Mesh)
  if not mesh_is_valid:
    raise ValueError(
        f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")

  # TODO(yashkatariya): Enable this.
  # if not core.trace_state_clean():
  #   raise ValueError('`use_mesh` can only be used outside of `jax.jit`')

  # The abstract mesh is entered first and the concrete mesh second; they are
  # unwound in the opposite order when the caller's `with` body finishes.
  with mesh_lib.set_abstract_mesh(mesh.abstract_mesh):
    with mesh_lib.set_concrete_mesh(mesh):
      yield
|
||||
|
||||
def set_mesh(mesh: mesh_lib.Mesh) -> None:
  """Install `mesh` as the current mesh; nothing here ever uninstalls it.

  Unlike `use_mesh`, this is a plain function rather than a context manager:
  it writes the new values with `set_local` and returns, so nothing restores
  the previous values afterwards.

  Args:
    mesh: the `jax.sharding.Mesh` to install.

  Raises:
    ValueError: if `mesh` is not a `mesh_lib.Mesh` instance, or if
      `core.trace_state_clean()` reports that the trace state is not clean.
  """
  mesh_is_valid = isinstance(mesh, mesh_lib.Mesh)
  if not mesh_is_valid:
    raise ValueError(
        f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")

  inside_trace = not core.trace_state_clean()
  if inside_trace:
    raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')

  # Same pair of settings that `use_mesh` manages, written without a matching
  # restore: abstract mesh first, then the concrete mesh.
  abstract_mesh = mesh.abstract_mesh
  config.abstract_mesh_context_manager.set_local(abstract_mesh)
  config.device_context.set_local(mesh)
|
||||
|
@ -1576,7 +1576,7 @@ def with_user_mesh(sizes, names, axis_types=None):
|
||||
def decorator(fn):
|
||||
def mesh_fn(*args, **kwargs):
|
||||
mesh = create_mesh(sizes, names, axis_types=axis_types)
|
||||
with mesh_lib.use_mesh(mesh):
|
||||
with jax.sharding.use_mesh(mesh):
|
||||
return fn(*args, **kwargs, mesh=mesh)
|
||||
return mesh_fn
|
||||
return decorator
|
||||
|
@ -22,6 +22,8 @@ from jax._src.sharding_impls import (
|
||||
PmapSharding as PmapSharding,
|
||||
GSPMDSharding as GSPMDSharding,
|
||||
PositionalSharding as PositionalSharding,
|
||||
use_mesh as use_mesh,
|
||||
set_mesh as set_mesh,
|
||||
)
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as PartitionSpec,
|
||||
@ -30,7 +32,6 @@ from jax._src.interpreters.pxla import Mesh as Mesh
|
||||
from jax._src.mesh import (
|
||||
AbstractMesh as AbstractMesh,
|
||||
AxisTypes as AxisTypes,
|
||||
use_mesh as use_mesh
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
|
@ -5943,7 +5943,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
a = z @ x2
|
||||
return a
|
||||
|
||||
with mesh_lib.use_mesh(mesh):
|
||||
with jax.sharding.use_mesh(mesh):
|
||||
out = f(arr, arr.T)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
|
||||
lowered_text = f.lower(arr, arr.T).as_text()
|
||||
@ -5952,7 +5952,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
|
||||
axis_types={mesh_lib.AxisTypes.Explicit: 'x',
|
||||
mesh_lib.AxisTypes.Auto: 'y'})
|
||||
with mesh_lib.use_mesh(mesh2):
|
||||
with jax.sharding.use_mesh(mesh2):
|
||||
arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
|
||||
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
|
||||
out = f(arr, arr2)
|
||||
@ -5966,7 +5966,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
|
||||
axis_types={mesh_lib.AxisTypes.Explicit: 'y',
|
||||
mesh_lib.AxisTypes.Auto: 'x'})
|
||||
with mesh_lib.use_mesh(mesh3):
|
||||
with jax.sharding.use_mesh(mesh3):
|
||||
arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
|
||||
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P(None, 'x')))
|
||||
out = f(arr, arr2)
|
||||
@ -6143,7 +6143,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
auto_mesh = jax.make_mesh((2,), 'x', auto_axes='x')
|
||||
with mesh_lib.use_mesh(auto_mesh):
|
||||
with jax.sharding.use_mesh(auto_mesh):
|
||||
arr2 = jnp.ones(8)
|
||||
self.assertDictEqual(arr2.sharding.mesh.axis_types,
|
||||
{AxisTypes.Auto: ('x',)})
|
||||
@ -7083,6 +7083,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = f(np.arange(8))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
||||
|
||||
def test_set_mesh(self):
  """`jax.sharding.set_mesh` makes the mesh visible to a later `reshard`."""
  explicit_mesh = jtu.create_mesh(
      (2,), ('x',), axis_types={AxisTypes.Explicit: 'x'})

  # `set_mesh` never restores the previous values, so snapshot them here and
  # put them back by hand once the assertions have run.
  saved_concrete_mesh = config.device_context.value
  saved_abstract_mesh = config.abstract_mesh_context_manager.value
  try:
    jax.sharding.set_mesh(explicit_mesh)

    resharded = reshard(np.arange(8), P('x'))
    self.assertEqual(
        resharded.sharding, NamedSharding(explicit_mesh, P('x')))
  finally:
    config.device_context.set_local(saved_concrete_mesh)
    config.abstract_mesh_context_manager.set_local(saved_abstract_mesh)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user