2024-08-12 22:13:36 -07:00
|
|
|
# Profiling device memory
|
2020-06-26 17:09:09 -04:00
|
|
|
|
2024-06-21 14:50:02 -07:00
|
|
|
<!--* freshness: { reviewed: '2024-03-08' } *-->
|
2023-05-30 14:21:28 -07:00
|
|
|
|
|
|
|
```{note}
|
|
|
|
May 2023 update: we recommend using [Tensorboard
|
|
|
|
profiling](tensorboard-profiling) for device memory analysis. After taking a
|
|
|
|
profile, open the `memory_viewer` tab of the Tensorboard profiler for more
|
|
|
|
detailed and understandable device memory usage.
|
|
|
|
```
|
|
|
|
|
2024-08-12 13:52:27 -07:00
|
|
|
The JAX device memory profiler allows us to explore how and why JAX programs are
|
2020-06-26 17:09:09 -04:00
|
|
|
using GPU or TPU memory. For example, it can be used to:
|
|
|
|
|
|
|
|
* Figure out which arrays and executables are in GPU memory at a given time, or
|
|
|
|
* Track down memory leaks.
|
|
|
|
|
|
|
|
## Installation
|
|
|
|
|
|
|
|
The JAX device memory profiler emits output that can be interpreted using
|
|
|
|
pprof (<https://github.com/google/pprof>). Start by installing `pprof`,
|
|
|
|
by following its
|
|
|
|
[installation instructions](https://github.com/google/pprof#building-pprof).
|
|
|
|
At the time of writing, installing `pprof` requires first installing
|
2022-06-21 11:15:14 +02:00
|
|
|
[Go](https://golang.org/) of version 1.16+,
|
|
|
|
[Graphviz](http://www.graphviz.org/), and then running
|
2020-06-26 17:09:09 -04:00
|
|
|
|
|
|
|
```shell
|
2022-06-21 11:15:14 +02:00
|
|
|
go install github.com/google/pprof@latest
|
2020-06-26 17:09:09 -04:00
|
|
|
```
|
|
|
|
|
|
|
|
which installs `pprof` as `$GOPATH/bin/pprof`, where `GOPATH` defaults to
|
|
|
|
`~/go`.
|
|
|
|
|
|
|
|
```{note}
|
|
|
|
The version of `pprof` from <https://github.com/google/pprof> is not the same as
|
|
|
|
the older tool of the same name distributed as part of the `gperftools` package.
|
|
|
|
The `gperftools` version of `pprof` will not work with JAX.
|
|
|
|
```
|
|
|
|
|
|
|
|
## Understanding how a JAX program is using GPU or TPU memory
|
|
|
|
|
|
|
|
A common use of the device memory profiler is to figure out why a JAX program is
|
|
|
|
using a large amount of GPU or TPU memory, for example if trying to debug an
|
|
|
|
out-of-memory problem.
|
|
|
|
|
|
|
|
To capture a device memory profile to disk, use
|
|
|
|
{func}`jax.profiler.save_device_memory_profile`. For example, consider the
|
|
|
|
following Python program:
|
|
|
|
|
|
|
|
```python
|
|
|
|
import jax
|
|
|
|
import jax.numpy as jnp
|
|
|
|
import jax.profiler
|
|
|
|
|
|
|
|
def func1(x):
|
|
|
|
return jnp.tile(x, 10) * 0.5
|
|
|
|
|
|
|
|
def func2(x):
|
|
|
|
y = func1(x)
|
|
|
|
return y, jnp.tile(x, 10) + 1
|
|
|
|
|
2023-08-14 16:31:17 -07:00
|
|
|
x = jax.random.normal(jax.random.key(42), (1000, 1000))
|
2020-06-26 17:09:09 -04:00
|
|
|
y, z = func2(x)
|
|
|
|
|
|
|
|
z.block_until_ready()
|
|
|
|
|
|
|
|
jax.profiler.save_device_memory_profile("memory.prof")
|
|
|
|
```
|
|
|
|
|
|
|
|
If we first run the program above and then execute
|
|
|
|
|
|
|
|
```shell
|
|
|
|
pprof --web memory.prof
|
|
|
|
```
|
|
|
|
|
|
|
|
`pprof` opens a web browser containing the following visualization of the device
|
|
|
|
memory profile in callgraph format:
|
|
|
|
|
|
|
|

|
|
|
|
|
|
|
|
The callgraph is a visualization of
|
|
|
|
the Python stack at the point the allocation of each live buffer was made.
|
|
|
|
For example, in this specific case, the visualization shows that
|
|
|
|
`func2` and its callees were responsible for allocating 76.30MB, of which
|
|
|
|
38.15MB was allocated inside the call from `func1` to `func2`.
|
|
|
|
For more information about how to interpret callgraph visualizations, see the
|
|
|
|
[pprof documentation](https://github.com/google/pprof/blob/master/doc/README.md#interpreting-the-callgraph).
|
|
|
|
|
|
|
|
Functions compiled with {func}`jax.jit` are opaque to the device memory profiler.
|
|
|
|
That is, any memory allocated inside a `jit`-compiled function will be
|
2021-08-02 17:57:09 -07:00
|
|
|
attributed to the function as a whole.
|
2020-06-26 17:09:09 -04:00
|
|
|
|
|
|
|
In the example, the call to `block_until_ready()` is to ensure that `func2`
|
|
|
|
completes before the device memory profile is collected. See
|
|
|
|
{doc}`async_dispatch` for more details.
|
|
|
|
|
|
|
|
## Debugging memory leaks
|
|
|
|
|
|
|
|
We can also use the JAX device memory profiler to track down memory leaks by using
|
|
|
|
`pprof` to visualize the change in memory usage between two device memory profiles
|
2021-08-02 17:57:09 -07:00
|
|
|
taken at different times. For example, consider the following program which
|
2020-06-26 17:09:09 -04:00
|
|
|
accumulates JAX arrays into a constantly-growing Python list.
|
|
|
|
|
|
|
|
```python
|
|
|
|
import jax
|
|
|
|
import jax.numpy as jnp
|
|
|
|
import jax.profiler
|
|
|
|
|
|
|
|
def afunction():
|
2023-08-14 16:31:17 -07:00
|
|
|
return jax.random.normal(jax.random.key(77), (1000000,))
|
2020-06-26 17:09:09 -04:00
|
|
|
|
|
|
|
z = afunction()
|
|
|
|
|
|
|
|
def anotherfunc():
|
|
|
|
arrays = []
|
|
|
|
for i in range(1, 10):
|
2023-08-14 16:31:17 -07:00
|
|
|
x = jax.random.normal(jax.random.key(42), (i, 10000))
|
2020-06-26 17:09:09 -04:00
|
|
|
arrays.append(x)
|
|
|
|
x.block_until_ready()
|
|
|
|
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
|
|
|
|
|
|
|
|
anotherfunc()
|
|
|
|
```
|
|
|
|
|
|
|
|
If we simply visualize the device memory profile at the end of execution
|
|
|
|
(`memory9.prof`), it may not be obvious that each iteration of the loop in
|
|
|
|
`anotherfunc` accumulates more device memory allocations:
|
|
|
|
|
|
|
|
```shell
|
|
|
|
pprof --web memory9.prof
|
|
|
|
```
|
|
|
|
|
|
|
|

|
|
|
|
|
|
|
|
The large but fixed allocation inside `afunction` dominates the profile but does
|
|
|
|
not grow over time.
|
|
|
|
|
|
|
|
By using `pprof`'s
|
|
|
|
[`--diff_base` feature](https://github.com/google/pprof/blob/master/doc/README.md#comparing-profiles) to visualize the change in memory usage
|
|
|
|
across loop iterations, we can identify why the memory usage of the
|
|
|
|
program increases over time:
|
|
|
|
|
|
|
|
```shell
|
|
|
|
pprof --web --diff_base memory1.prof memory9.prof
|
|
|
|
```
|
|
|
|
|
|
|
|

|
|
|
|
|
|
|
|
The visualization shows that the memory growth can be attributed to the call to
|
2021-06-18 08:55:08 +03:00
|
|
|
`normal` inside `anotherfunc`.
|