Add a documentation for the JAX transfer guard

This commit is contained in:
Hyeontaek Lim 2022-03-23 21:37:52 +00:00
parent f57e78e240
commit 87671f4d85
3 changed files with 74 additions and 1 deletions

View File

@ -58,6 +58,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
device_memory_profiling
rank_promotion_warning
custom_vjp_update
transfer_guard
.. toctree::
:maxdepth: 2

View File

@ -17,4 +17,4 @@ JAX configuration
default_matmul_precision
default_prng_impl
numpy_rank_promotion
transfer_guard

72
docs/transfer_guard.rst Normal file
View File

@ -0,0 +1,72 @@
Transfer guard
==============
JAX may transfer data between the host and devices and between devices during
type conversion and input sharding. To log or disallow any unintended
transfers, the user may configure a JAX transfer guard.
JAX transfer guards distinguish between two types of transfers:
* Explicit transfers: ``jax.device_put*()`` and ``jax.device_get()`` calls.
* Implicit transfers: Other transfers (e.g., printing a ``DeviceArray``).
A transfer guard can take an action based on its guard level:
* ``"allow"``: Silently allow all transfers (default).
* ``"log"``: Log and allow implicit transfers. Silently allow explicit
transfers.
* ``"disallow"``: Disallow implicit transfers. Silently allow explicit
transfers.
* ``"log_explicit"``: Log and allow all transfers.
* ``"disallow_explicit"``: Disallow all transfers.
JAX will raise a ``RuntimeError`` when disallowing a transfer.
The transfer guards use the standard JAX configuration system:
* A ``--jax_transfer_guard=GUARD_LEVEL`` command-line flag and
``jax.config.update("jax_transfer_guard", GUARD_LEVEL)`` will set the global
option.
* A ``with jax.transfer_guard(GUARD_LEVEL): ...`` context manager will set the
thread-local option within the scope of the context manager.
Note that similar to other JAX configuration options, a newly spawned thread
will use the global option instead of any active thread-local option of the
scope where the thread was spawned.
The transfer guards can also be applied more selectively, based on the
direction of transfer. The flag and context manager name is suffixed with a
corresponding transfer direction (e.g., ``--jax_transfer_guard_host_to_device``
and ``jax.config.transfer_guard_host_to_device``):
* ``"host_to_device"``: Converting a Python value or NumPy array into a JAX
on-device buffer.
* ``"device_to_device"``: Copying a JAX on-device buffer to a different device.
* ``"device_to_host"``: Fetching a JAX on-device buffer.
Fetching a buffer on a CPU device is always allowed regardless of the transfer
guard level.
The following shows an example of using the transfer guard.
.. code-block:: python
>>> jax.config.update("jax_transfer_guard", "allow") # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x) # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
... print("x", x) # x has already been fetched into the host.
... print("y", jax.device_get(y)) # Explicit transfers are allowed.
... try:
... print("z", z) # Implicit transfers are disallowed.
... assert False, "This line is expected to be unreachable."
... except:
... print("z could not be fetched") # doctest: +SKIP
x 1
y 2
z could not be fetched