diff --git a/jax/_src/array.py b/jax/_src/array.py
index 2f10d8de8..b0793d2c3 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -33,7 +33,6 @@ from jax._src import errors
 from jax._src import profiler
 from jax._src import util
 from jax._src import xla_bridge
-from jax._src.mesh import use_concrete_mesh
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
@@ -43,7 +42,8 @@ from jax._src.lib import xla_extension as xe
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding,
-    device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape)  # pyformat: disable
+    device_replica_id_map, hashed_index, num_addressable_indices,
+    local_to_global_shape, use_concrete_mesh)  # pyformat: disable
 from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike
 from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
 import numpy as np
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 4cb8ba0af..b490febf7 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -543,6 +543,10 @@ class UseAbstractMeshContextManager:
   __slots__ = ['mesh', 'prev']
 
   def __init__(self, mesh: AbstractMesh):
+    if not isinstance(mesh, AbstractMesh):
+      raise ValueError(
+          "Expected mesh of type `jax.sharding.AbstractMesh`. Got type:"
+          f" {type(mesh)}")
     self.mesh = mesh
 
   def __enter__(self):
@@ -557,13 +561,5 @@ def get_abstract_mesh():
   val = jax_config.abstract_mesh_context_manager.value
   return empty_abstract_mesh if val is None else val
 
-@contextlib.contextmanager
-def use_concrete_mesh(mesh: Mesh | None):
-  prev_val = jax_config.device_context.swap_local(mesh)
-  try:
-    yield
-  finally:
-    jax_config.device_context.set_local(prev_val)
-
 def get_concrete_mesh() -> Mesh | None:
   return jax_config.device_context.value
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 51c4ad639..2bbf91378 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -1387,8 +1387,7 @@ def use_mesh(mesh: mesh_lib.Mesh):
   # if not core.trace_state_clean():
   #   raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
 
-  with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh),
-        mesh_lib.use_concrete_mesh(mesh)):
+  with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh):
     yield
 
 def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
@@ -1408,3 +1407,18 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
   prev_mesh = config.device_context.get_global()
   config.device_context.set_global(mesh)
   return prev_mesh
+
+@contextlib.contextmanager
+def use_concrete_mesh(mesh: mesh_lib.Mesh | None):
+  if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
+    raise ValueError(
+        f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
+  # TODO(yashkatariya): Enable this.
+  # if not core.trace_state_clean():
+  #   raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')
+
+  prev_val = config.device_context.swap_local(mesh)
+  try:
+    yield
+  finally:
+    config.device_context.set_local(prev_val)