Add docs on GPU memory allocation.

Commit 5b9849f177 (parent b9a984f438)
Author: Skye Wanderman-Milne
Date:   2019-07-29 12:24:58 -07:00
2 changed files with 43 additions and 0 deletions

GPU memory allocation
=====================
**JAX will preallocate 90% of currently-available GPU memory when the first JAX
operation is run.** Preallocating minimizes allocation overhead and memory
fragmentation, but can sometimes cause out-of-memory (OOM) errors. If your JAX
process fails with OOM, the following environment variables can be used to
override the default behavior:

``XLA_PYTHON_CLIENT_PREALLOCATE=false``
   This disables the preallocation behavior. JAX will instead allocate GPU
   memory as needed, potentially decreasing the overall memory usage. However,
   this behavior is more prone to GPU memory fragmentation, meaning a JAX
   program that uses most of the available GPU memory may OOM with
   preallocation disabled.

``XLA_PYTHON_CLIENT_MEM_FRACTION=.XX``
   If preallocation is enabled, this makes JAX preallocate XX% of
   currently-available GPU memory, instead of the default 90%. Lowering the
   amount preallocated can fix OOMs that occur when the JAX program starts.

``XLA_PYTHON_CLIENT_ALLOCATOR=platform``
   This makes JAX allocate exactly what is needed on demand. This is very
   slow, so it is not recommended for general use, but may be useful for
   debugging OOM failures.

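For example, these variables can also be set from Python, as long as that
happens before JAX is first imported in the process. The variable names below
are the real ones described above; the chosen values are only illustrative:

```python
import os

# These variables are read when JAX initializes its GPU backend, so they
# must be set before the first `import jax` in the process; setting them
# afterwards has no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternatively, keep preallocation but shrink it from the default 90% to
# 50% of currently-available GPU memory:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# `import jax` would go here, after the environment is configured.
```

Exporting the variables in the shell before launching the Python process has
the same effect.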
Common causes of OOM failures
-----------------------------

**Running multiple JAX processes concurrently.**

Either use ``XLA_PYTHON_CLIENT_MEM_FRACTION`` to give each process an
appropriate amount of memory, or set ``XLA_PYTHON_CLIENT_PREALLOCATE=false``.
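As a sketch of the first option, the per-process fraction can be derived by
splitting the 90% default evenly across the processes sharing a GPU
(``mem_fraction_for`` is a hypothetical helper, not part of JAX):

```python
import os

def mem_fraction_for(num_processes: int, total_percent: int = 90) -> str:
    """Hypothetical helper: split `total_percent` of GPU memory evenly
    across `num_processes`, formatted as the ".XX" string that
    XLA_PYTHON_CLIENT_MEM_FRACTION expects."""
    return f".{total_percent // num_processes:02d}"

# Give each of 4 concurrent JAX processes ~22% of GPU memory. As with the
# other variables, this must be set before JAX is imported in each process.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = mem_fraction_for(4)
```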
**Running JAX and GPU TensorFlow concurrently.**

TensorFlow also preallocates by default, so this is similar to running
multiple JAX processes concurrently. One solution is to use CPU-only
TensorFlow (e.g. if you're only doing data loading with TF). Alternatively,
use ``XLA_PYTHON_CLIENT_MEM_FRACTION`` or ``XLA_PYTHON_CLIENT_PREALLOCATE``.

**Running JAX on the display GPU.**

Use ``XLA_PYTHON_CLIENT_MEM_FRACTION`` or ``XLA_PYTHON_CLIENT_PREALLOCATE``.

The second changed file is ``index.rst``, where the new page is added to the
``Notes`` toctree::

      :caption: Notes

      async_dispatch
      gpu_memory_allocation

   .. toctree::
      :maxdepth: 3