Fix a weird interaction with set_local and empty tuples passed to it.

PiperOrigin-RevId: 700392735
This commit is contained in:
Yash Katariya 2024-11-26 10:48:12 -08:00 committed by jax authors
parent e453fa179e
commit 6763fcfb4e
2 changed files with 9 additions and 5 deletions

View File

@ -483,15 +483,19 @@ mesh_context = MeshContext()
def push_mesh_context(val):
mesh_context.stack.append(val)
mesh_context.mesh = val
jax_config.abstract_mesh_context_manager.set_local(
tuple(m for m in mesh_context.stack if m is not None))
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
# Right now that leads to weird numerical issues.
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
if non_none_meshes:
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
return val
def pop_mesh_context():
mesh_context.stack.pop()
mesh_context.mesh = mesh_context.stack[-1]
jax_config.abstract_mesh_context_manager.set_local(
tuple(m for m in mesh_context.stack if m is not None))
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
if non_none_meshes:
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
class null_mesh_context:

View File

@ -709,7 +709,7 @@ def get_abstract_mesh(in_avals):
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if m is None:
return mesh_lib.null_mesh_context()
assert m is not None
assert isinstance(m, AbstractMesh)
return m