Add a private _extremely_unsafe_enter_tracing_context to enter abstractMesh into tracing context. This is a temporary workaround for internal use cases.

PiperOrigin-RevId: 682960902
This commit is contained in:
Yash Katariya 2024-10-06 14:49:43 -07:00 committed by jax authors
parent e16fac67da
commit c6f7316d43

View File

@ -407,6 +407,11 @@ class AbstractMesh:
def __exit__(self, exc_type, exc_value, traceback):
raise RuntimeError("AbstractMesh is not a context manager")
@staticmethod
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
jax_config.update_thread_local_jit_state(mesh_context_manager=mesh)
return
# Create this indirection because pytype fails to recognize a property if a
# property raises an exception unconditionally. Remove this once that is fixed.