mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add a documentation for the JAX transfer guard
This commit is contained in:
parent
f57e78e240
commit
87671f4d85
@ -58,6 +58,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
|
||||
device_memory_profiling
|
||||
rank_promotion_warning
|
||||
custom_vjp_update
|
||||
transfer_guard
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
@ -17,4 +17,4 @@ JAX configuration
|
||||
default_matmul_precision
|
||||
default_prng_impl
|
||||
numpy_rank_promotion
|
||||
|
||||
transfer_guard
|
||||
|
72
docs/transfer_guard.rst
Normal file
72
docs/transfer_guard.rst
Normal file
@ -0,0 +1,72 @@
|
||||
Transfer guard
|
||||
==============
|
||||
|
||||
JAX may transfer data between the host and devices and between devices during
|
||||
type conversion and input sharding. To log or disallow any unintended
|
||||
transfers, the user may configure a JAX transfer guard.
|
||||
|
||||
JAX transfer guards distinguish between two types of transfers:
|
||||
|
||||
* Explicit transfers: ``jax.device_put*()`` and ``jax.device_get()`` calls.
|
||||
* Implicit transfers: Other transfers (e.g., printing a ``DeviceArray``).
|
||||
|
||||
A transfer guard can take an action based on its guard level:
|
||||
|
||||
* ``"allow"``: Silently allow all transfers (default).
|
||||
* ``"log"``: Log and allow implicit transfers. Silently allow explicit
|
||||
transfers.
|
||||
* ``"disallow"``: Disallow implicit transfers. Silently allow explicit
|
||||
transfers.
|
||||
* ``"log_explicit"``: Log and allow all transfers.
|
||||
* ``"disallow_explicit"``: Disallow all transfers.
|
||||
|
||||
JAX will raise a ``RuntimeError`` when disallowing a transfer.
|
||||
|
||||
The transfer guards use the standard JAX configuration system:
|
||||
|
||||
* A ``--jax_transfer_guard=GUARD_LEVEL`` command-line flag and
|
||||
``jax.config.update("jax_transfer_guard", GUARD_LEVEL)`` will set the global
|
||||
option.
|
||||
* A ``with jax.transfer_guard(GUARD_LEVEL): ...`` context manager will set the
|
||||
thread-local option within the scope of the context manager.
|
||||
|
||||
Note that similar to other JAX configuration options, a newly spawned thread
|
||||
will use the global option instead of any active thread-local option of the
|
||||
scope where the thread was spawned.
|
||||
|
||||
The transfer guards can also be applied more selectively, based on the
|
||||
direction of transfer. The flag and context manager name is suffixed with a
|
||||
corresponding transfer direction (e.g., ``--jax_transfer_guard_host_to_device``
|
||||
and ``jax.config.transfer_guard_host_to_device``):
|
||||
|
||||
* ``"host_to_device"``: Converting a Python value or NumPy array into a JAX
|
||||
on-device buffer.
|
||||
* ``"device_to_device"``: Copying a JAX on-device buffer to a different device.
|
||||
* ``"device_to_host"``: Fetching a JAX on-device buffer.
|
||||
|
||||
Fetching a buffer on a CPU device is always allowed regardless of the transfer
|
||||
guard level.
|
||||
|
||||
The following shows an example of using the transfer guard.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>>> jax.config.update("jax_transfer_guard", "allow") # This is default.
|
||||
>>>
|
||||
>>> x = jnp.array(1)
|
||||
>>> y = jnp.array(2)
|
||||
>>> z = jnp.array(3)
|
||||
>>>
|
||||
>>> print("x", x) # All transfers are allowed.
|
||||
x 1
|
||||
>>> with jax.transfer_guard("disallow"):
|
||||
... print("x", x) # x has already been fetched into the host.
|
||||
... print("y", jax.device_get(y)) # Explicit transfers are allowed.
|
||||
... try:
|
||||
... print("z", z) # Implicit transfers are disallowed.
|
||||
... assert False, "This line is expected to be unreachable."
|
||||
... except:
|
||||
... print("z could not be fetched") # doctest: +SKIP
|
||||
x 1
|
||||
y 2
|
||||
z could not be fetched
|
Loading…
x
Reference in New Issue
Block a user