mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #11042 from hyeontaek:transfer-guard-doc-discoverability
PiperOrigin-RevId: 453994810
This commit is contained in:
commit
859883cfae
@ -30,6 +30,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
|
||||
pytrees
|
||||
type_promotion
|
||||
errors
|
||||
transfer_guard
|
||||
glossary
|
||||
changelog
|
||||
|
||||
@ -59,7 +60,6 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
|
||||
device_memory_profiling
|
||||
rank_promotion_warning
|
||||
custom_vjp_update
|
||||
transfer_guard
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
@ -970,7 +970,17 @@ _transfer_guard = config.define_enum_state(
|
||||
|
||||
@contextlib.contextmanager
|
||||
def transfer_guard(new_val: str) -> Iterator[None]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it."""
|
||||
"""A contextmanager to control the transfer guard level for all transfers.
|
||||
|
||||
For more information, see
|
||||
https://jax.readthedocs.io/en/latest/transfer_guard.html
|
||||
|
||||
Args:
|
||||
new_val: The new thread-local transfer guard level for all transfers.
|
||||
|
||||
Yields:
|
||||
None.
|
||||
"""
|
||||
with contextlib.ExitStack() as stack:
|
||||
stack.enter_context(transfer_guard_host_to_device(new_val))
|
||||
stack.enter_context(transfer_guard_device_to_device(new_val))
|
||||
|
Loading…
x
Reference in New Issue
Block a user