mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix a weird interaction with set_local
and empty tuples passed to it.
PiperOrigin-RevId: 700392735
This commit is contained in:
parent
e453fa179e
commit
6763fcfb4e
@ -483,15 +483,19 @@ mesh_context = MeshContext()
|
||||
def push_mesh_context(val):
|
||||
mesh_context.stack.append(val)
|
||||
mesh_context.mesh = val
|
||||
jax_config.abstract_mesh_context_manager.set_local(
|
||||
tuple(m for m in mesh_context.stack if m is not None))
|
||||
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
|
||||
# Right now that leads to weird numerical issues.
|
||||
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
|
||||
return val
|
||||
|
||||
def pop_mesh_context():
|
||||
mesh_context.stack.pop()
|
||||
mesh_context.mesh = mesh_context.stack[-1]
|
||||
jax_config.abstract_mesh_context_manager.set_local(
|
||||
tuple(m for m in mesh_context.stack if m is not None))
|
||||
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
|
||||
|
||||
|
||||
class null_mesh_context:
|
||||
|
@ -709,7 +709,7 @@ def get_abstract_mesh(in_avals):
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
if m is None:
|
||||
return mesh_lib.null_mesh_context()
|
||||
assert m is not None
|
||||
assert isinstance(m, AbstractMesh)
|
||||
return m
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user