From 663ef7ae0120fe0b91cb32bb0ad8b1ae5b847f12 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 18 Mar 2025 16:56:50 -0700 Subject: [PATCH] Check the type of mesh in `use_abstract_mesh` and `use_concrete_mesh` PiperOrigin-RevId: 738190879 --- jax/_src/array.py | 4 ++-- jax/_src/mesh.py | 12 ++++-------- jax/_src/sharding_impls.py | 18 ++++++++++++++++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 2f10d8de8..b0793d2c3 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -33,7 +33,6 @@ from jax._src import errors from jax._src import profiler from jax._src import util from jax._src import xla_bridge -from jax._src.mesh import use_concrete_mesh from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -43,7 +42,8 @@ from jax._src.lib import xla_extension as xe from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, - device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable + device_replica_id_map, hashed_index, num_addressable_indices, + local_to_global_shape, use_concrete_mesh) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 4cb8ba0af..b490febf7 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -543,6 +543,10 @@ class UseAbstractMeshContextManager: __slots__ = ['mesh', 'prev'] def __init__(self, mesh: AbstractMesh): + if not isinstance(mesh, AbstractMesh): + raise ValueError( + "Expected mesh of type `jax.sharding.AbstractMesh`. Got type:" + f" {type(mesh)}") self.mesh = mesh def __enter__(self): @@ -557,13 +561,5 @@ def get_abstract_mesh(): val = jax_config.abstract_mesh_context_manager.value return empty_abstract_mesh if val is None else val -@contextlib.contextmanager -def use_concrete_mesh(mesh: Mesh | None): - prev_val = jax_config.device_context.swap_local(mesh) - try: - yield - finally: - jax_config.device_context.set_local(prev_val) - def get_concrete_mesh() -> Mesh | None: return jax_config.device_context.value diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 51c4ad639..2bbf91378 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1387,8 +1387,7 @@ def use_mesh(mesh: mesh_lib.Mesh): # if not core.trace_state_clean(): # raise ValueError('`use_mesh` can only be used outside of `jax.jit`') - with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh), - mesh_lib.use_concrete_mesh(mesh)): + with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh): yield def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: @@ -1408,3 +1407,18 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: prev_mesh = config.device_context.get_global() config.device_context.set_global(mesh) return prev_mesh + +@contextlib.contextmanager +def use_concrete_mesh(mesh: mesh_lib.Mesh | None): + if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): + raise ValueError( + f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") + # TODO(yashkatariya): Enable this. + # if not core.trace_state_clean(): + # raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') + + prev_val = config.device_context.swap_local(mesh) + try: + yield + finally: + config.device_context.set_local(prev_val)