mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add a private _extremely_unsafe_enter_tracing_context
to enter abstractMesh into tracing context. This is a temporary workaround for internal use cases.
PiperOrigin-RevId: 682960902
This commit is contained in:
parent
e16fac67da
commit
c6f7316d43
@ -407,6 +407,11 @@ class AbstractMesh:
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
raise RuntimeError("AbstractMesh is not a context manager")
|
||||
|
||||
@staticmethod
|
||||
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
|
||||
jax_config.update_thread_local_jit_state(mesh_context_manager=mesh)
|
||||
return
|
||||
|
||||
|
||||
# Create this indirection because pytype fails to recognize a property if a
|
||||
# property raises an exception unconditionally. Remove this once that is fixed.
|
||||
|
Loading…
x
Reference in New Issue
Block a user