Merge pull request #11042 from hyeontaek:transfer-guard-doc-discoverability

PiperOrigin-RevId: 453994810
This commit is contained in:
jax authors 2022-06-09 13:12:10 -07:00
commit 859883cfae
2 changed files with 12 additions and 2 deletions

View File

@ -30,6 +30,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
pytrees
type_promotion
errors
transfer_guard
glossary
changelog
@ -59,7 +60,6 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
device_memory_profiling
rank_promotion_warning
custom_vjp_update
transfer_guard
.. toctree::
:maxdepth: 2

View File

@ -970,7 +970,17 @@ _transfer_guard = config.define_enum_state(
@contextlib.contextmanager
def transfer_guard(new_val: str) -> Iterator[None]:
"""Set up thread-local state and return a contextmanager for managing it."""
"""A contextmanager to control the transfer guard level for all transfers.
For more information, see
https://jax.readthedocs.io/en/latest/transfer_guard.html
Args:
new_val: The new thread-local transfer guard level for all transfers.
Yields:
None.
"""
with contextlib.ExitStack() as stack:
stack.enter_context(transfer_guard_host_to_device(new_val))
stack.enter_context(transfer_guard_device_to_device(new_val))