Add docs on GPU memory allocation.
This commit is contained in:
parent b9a984f438
commit 5b9849f177
docs/gpu_memory_allocation.rst (new file, 42 lines)
@@ -0,0 +1,42 @@
GPU memory allocation
=====================

**JAX will preallocate 90% of currently-available GPU memory when the first JAX
operation is run.** Preallocating minimizes allocation overhead and memory
fragmentation, but can sometimes cause out-of-memory (OOM) errors. If your JAX
process fails with OOM, the following environment variables can be used to
override the default behavior:

``XLA_PYTHON_CLIENT_PREALLOCATE=false``
  This disables the preallocation behavior. JAX will instead allocate GPU
  memory as needed, potentially decreasing the overall memory usage. However,
  this behavior is more prone to GPU memory fragmentation, meaning a JAX
  program that uses most of the available GPU memory may OOM with
  preallocation disabled.

``XLA_PYTHON_CLIENT_MEM_FRACTION=.XX``
  If preallocation is enabled, this makes JAX preallocate XX% of
  currently-available GPU memory, instead of the default 90%. Lowering the
  amount preallocated can fix OOMs that occur when the JAX program starts.

``XLA_PYTHON_CLIENT_ALLOCATOR=platform``
  This makes JAX allocate exactly what is needed on demand. This is very slow,
  so it is not recommended for general use, but it may be useful for debugging
  OOM failures.
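
For example, a minimal sketch of setting these variables from Python rather
than the shell (they must be set before JAX initializes its GPU backend; the
75% fraction is just an illustrative value):

.. code-block:: python

  import os

  # Must be set before the first JAX operation touches the GPU; setting it
  # before importing jax is the safest option.
  os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
  # Alternatively, shrink the preallocated pool to 75% of available memory:
  # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".75"

  import jax.numpy as jnp

  x = jnp.ones((1024, 1024))  # first operation: triggers GPU allocation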

Common causes of OOM failures
-----------------------------

**Running multiple JAX processes concurrently.**
Either use ``XLA_PYTHON_CLIENT_MEM_FRACTION`` to give each process an
appropriate amount of memory, or set ``XLA_PYTHON_CLIENT_PREALLOCATE=false``.
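
For instance, a hypothetical launcher sketch that splits the GPU between two
worker processes (``worker.py`` is a placeholder for your own JAX script, and
45% is simply the default 90% divided by two):

.. code-block:: python

  import os
  import subprocess

  # Give each worker ~45% of GPU memory so that together they stay under
  # the default 90% cap.
  env = dict(os.environ, XLA_PYTHON_CLIENT_MEM_FRACTION=".45")
  workers = [subprocess.Popen(["python", "worker.py"], env=env)
             for _ in range(2)]
  for w in workers:
      w.wait()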

**Running JAX and GPU TensorFlow concurrently.**
TensorFlow also preallocates by default, so this is similar to running
multiple JAX processes concurrently. One solution is to use CPU-only
TensorFlow (e.g. if you're only doing data loading with TF). Alternatively,
use ``XLA_PYTHON_CLIENT_MEM_FRACTION`` or ``XLA_PYTHON_CLIENT_PREALLOCATE``.
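
For the CPU-only-TensorFlow route, a minimal sketch assuming TensorFlow 2.x
(the call must run before TensorFlow initializes its GPU devices):

.. code-block:: python

  import tensorflow as tf

  # Hide all GPUs from TensorFlow so it cannot preallocate GPU memory;
  # tf.data pipelines and other CPU-side work are unaffected.
  tf.config.set_visible_devices([], "GPU")

  import jax.numpy as jnp  # JAX can still use the GPU as usual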
|
||||
**Running JAX on the display GPU.**
|
||||
Use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE.
|

docs/index.rst
@@ -12,6 +12,7 @@ For an introduction to JAX, start at the
   :caption: Notes

   async_dispatch
   gpu_memory_allocation

.. toctree::
   :maxdepth: 3