From 4f8881539d6692a6e287aa34628a178fb1890bf6 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 9 Jun 2022 17:56:03 +0000 Subject: [PATCH] Make the transfer guard documentation easier to find Move the main documentation of the transfer guard from "Notes" to "Reference documentation" section for better visibility. Add a link to the main documentation to the docstring of jax.transfer_guard(), which currently shows up as the top result when searching for "jax transfer_guard". --- docs/index.rst | 2 +- jax/_src/config.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 86d902fe8..51054aa4f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,6 +30,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more. pytrees type_promotion errors + transfer_guard glossary changelog @@ -59,7 +60,6 @@ parallelize, Just-In-Time compile to GPU/TPU, and more. device_memory_profiling rank_promotion_warning custom_vjp_update - transfer_guard .. toctree:: :maxdepth: 2 diff --git a/jax/_src/config.py b/jax/_src/config.py index 65917fb96..d66aa3036 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -970,7 +970,17 @@ _transfer_guard = config.define_enum_state( @contextlib.contextmanager def transfer_guard(new_val: str) -> Iterator[None]: - """Set up thread-local state and return a contextmanager for managing it.""" + """A contextmanager to control the transfer guard level for all transfers. + + For more information, see + https://jax.readthedocs.io/en/latest/transfer_guard.html + + Args: + new_val: The new thread-local transfer guard level for all transfers. + + Yields: + None. + """ with contextlib.ExitStack() as stack: stack.enter_context(transfer_guard_host_to_device(new_val)) stack.enter_context(transfer_guard_device_to_device(new_val))