From 6763fcfb4e15cb8cb3260d713df82c836a08918d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 26 Nov 2024 10:48:12 -0800 Subject: [PATCH] Fix a weird interaction with `set_local` and empty tuples passed to it. PiperOrigin-RevId: 700392735 --- jax/_src/mesh.py | 12 ++++++++---- jax/_src/pjit.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 214fb190d..cecc24fb2 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -483,15 +483,19 @@ mesh_context = MeshContext() def push_mesh_context(val): mesh_context.stack.append(val) mesh_context.mesh = val - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) + # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them. + # Right now that leads to weird numerical issues. + non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) + if non_none_meshes: + jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) return val def pop_mesh_context(): mesh_context.stack.pop() mesh_context.mesh = mesh_context.stack[-1] - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) + non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) + if non_none_meshes: + jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) class null_mesh_context: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b77af1a8f..a2c0aff98 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -709,7 +709,7 @@ def get_abstract_mesh(in_avals): # TODO(yashkatariya): Remove this when mesh context can be set by the user. if m is None: return mesh_lib.null_mesh_context() - assert m is not None + assert isinstance(m, AbstractMesh) return m