diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 214fb190d..cecc24fb2 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -483,15 +483,19 @@ mesh_context = MeshContext() def push_mesh_context(val): mesh_context.stack.append(val) mesh_context.mesh = val - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) + # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them. + # Right now that leads to weird numerical issues. + non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) + if non_none_meshes: + jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) return val def pop_mesh_context(): mesh_context.stack.pop() mesh_context.mesh = mesh_context.stack[-1] - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) + non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) + if non_none_meshes: + jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) class null_mesh_context: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b77af1a8f..a2c0aff98 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -709,7 +709,7 @@ def get_abstract_mesh(in_avals): # TODO(yashkatariya): Remove this when mesh context can be set by the user. if m is None: return mesh_lib.null_mesh_context() - assert m is not None + assert isinstance(m, AbstractMesh) return m