Document cudaMallocAsync as an experimental feature.

Frederic Bastien 2024-11-21 13:25:07 -05:00
parent e707edeafa
commit a13b618c98


@@ -70,3 +70,31 @@ Common causes of OOM failures
memory. Note, however, that the algorithm is basic, and you can often get a better
trade-off between compute and memory by disabling the automatic remat pass and doing
it manually with `the jax.remat API <https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html>`_.
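
As a minimal sketch of manual rematerialization (the layer, shapes, and loss
below are illustrative, not part of this change):

.. code-block:: python

   import jax
   import jax.numpy as jnp

   # Recompute this layer's intermediates during the backward pass
   # instead of storing them, trading compute for memory.
   @jax.remat  # jax.checkpoint is an alias
   def layer(x, w):
       return jnp.tanh(x @ w)

   def loss(x, w):
       return jnp.sum(layer(x, w) ** 2)

   grads = jax.grad(loss, argnums=1)(jnp.ones((8, 8)), jnp.ones((8, 8)))
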
Experimental features
---------------------
The features described here are experimental and should be used with caution.
``TF_GPU_ALLOCATOR=cuda_malloc_async``
This replaces XLA's own BFC memory allocator with `cudaMallocAsync
<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`_.
This removes the large fixed pre-allocation and instead uses a memory pool that
grows on demand. The expected benefit is that you no longer need to set
``XLA_PYTHON_CLIENT_MEM_FRACTION``.
The risks are:

- Memory fragmentation behaves differently, so if you are close to the memory
  limit, the exact OOM failures caused by fragmentation will differ.
- The allocation cost is not all paid up front, but is incurred whenever the
  memory pool needs to grow. This means timing can be less stable at the start,
  and for benchmarks it is even more important to ignore the first few
  iterations (see the warm-up sketch after this list).
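
For example, a minimal warm-up pattern for such benchmarks (the jitted function
and iteration counts are illustrative assumptions):

.. code-block:: python

   import time
   import jax
   import jax.numpy as jnp

   f = jax.jit(lambda x: (x @ x).sum())
   x = jnp.ones((1024, 1024))

   # The first calls pay compilation and memory-pool growth costs.
   for _ in range(3):
       f(x).block_until_ready()

   start = time.perf_counter()
   for _ in range(10):
       f(x).block_until_ready()
   print("mean step time:", (time.perf_counter() - start) / 10)
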
These risks can be mitigated by pre-allocating a significant chunk of memory
while still getting the benefit of a growing memory pool. This can be done
with ``TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N``. If N is ``-1``, it
preallocates the same amount as the default allocator would. Otherwise, N is
the number of bytes you want to preallocate.
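
As a minimal sketch, both variables can be set from Python before JAX
initializes its backend; the 2 GiB preallocation value is an illustrative
assumption:

.. code-block:: python

   import os

   # Must be set before JAX initializes the GPU backend.
   os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
   # Illustrative: preallocate 2 GiB up front to reduce early pool growth.
   # A value of -1 would instead preallocate as much as the default allocator.
   os.environ["TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC"] = str(2 * 1024**3)

   import jax
   print(jax.devices())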