diff --git a/docs/_static/device_memory_profile.svg b/docs/_static/device_memory_profile.svg new file mode 100644 index 000000000..fb0b5e498 --- /dev/null +++ b/docs/_static/device_memory_profile.svg @@ -0,0 +1,435 @@ + + + + + + +unnamed + + +cluster_L + + + + +Type: space + +Type: space +Showing nodes accounting for 80.11MB, 100% of 80.13MB total +Dropped 26 nodes (cum <= 0.40MB) +See https://git.io/JfYMW for how to read the graph + + + +N1 + + +_execute_compiled_primitive +76.29MB (95.21%) +of 76.29MB (95.21%) + + + + + +N1_0 + + + + + +kind:buffer + + + + + +N1->N1_0 + + + + + + + 76.29MB + + + + + +N2 + + +<unknown> +0 of 80.13MB (100%) + + + + + +N4 + + +func2 +0 of 76.30MB (95.22%) + + + + + +N2->N4 + + + + + + + 76.30MB + + + + + +N15 + + +normal +0 of 3.83MB (4.78%) + + + + + +N2->N15 + + + + + + + 3.83MB + + + + + +N3 + + +bind +0 of 76.30MB (95.22%) + + + + + +N11 + + +apply_primitive +0 of 76.30MB (95.22%) + + + + + +N3->N11 + + + + + + + 76.30MB + + + + + +N5 + + +deferring_binary_op +0 of 76.30MB (95.22%) + + + + + +N4->N5 + + + + + + + 38.15MB + + + + + +N8 + + +func1 +0 of 38.15MB (47.61%) + + + + + +N4->N8 + + + + + + + 38.15MB + + + + + +N6 + + +fn +0 of 76.30MB (95.22%) + + + + + +N5->N6 + + + + + + + 76.30MB + + + + + +N10 + + +add +0 of 38.15MB (47.61%) + + + + + +N6->N10 + + + + + + + 38.15MB + + + + + +N14 + + +mul +0 of 38.15MB (47.61%) + + + + + +N6->N14 + + + + + + + 38.15MB + + + + + +N7 + + +_execute_compiled +3.81MB (4.76%) + + + + + +N7_0 + + + + + +kind:buffer + + + + + +N7->N7_0 + + + + + + + 3.81MB + + + + + +N8->N5 + + + + + + + 38.15MB + + + + + +N9 + + +_xla_call_impl +0 of 3.83MB (4.78%) + + + + + +N9->N7 + + + + + + + 3.81MB + + + + + +N10->N3 + + + + + + + 38.15MB + + + + + +N11->N1 + + + + + + + 76.29MB + + + + + +N12 + + +call_bind +0 of 3.83MB (4.78%) + + + + + +N12->N9 + + + + + + + 3.83MB + + + + + +N13 + + +f_jitted +0 of 3.83MB (4.78%) + + + + + +N13->N12 + + + + + + + 3.83MB + + + + + +N14->N3 + + + + + + + 38.15MB + + + + + +N15->N13 + + + + + + + 3.83MB + + + + + diff --git a/docs/_static/device_memory_profile_leak1.svg b/docs/_static/device_memory_profile_leak1.svg new file mode 100644 index 000000000..53fc2848a --- /dev/null +++ b/docs/_static/device_memory_profile_leak1.svg @@ -0,0 +1,307 @@ + + + + + + +unnamed + + +cluster_L + + + + +Type: space + +Type: space +Showing nodes accounting for 5806.95kB, 99.79% of 5819.22kB total +Dropped 25 nodes (cum <= 29.10kB) +See https://git.io/JfYMW for how to read the graph + + + +N1 + + +_execute_compiled +5664.06kB (97.33%) + + + + + +N1_0 + + + + + +kind:buffer + + + + + +N1->N1_0 + + + + + + + 5664.06kB + + + + + +N2 + + +<unknown> +0 of 5819.22kB (100%) + + + + + +N6 + + +afunction +0 of 3925.29kB (67.45%) + + + + + +N2->N6 + + + + + + + 3925.29kB + + + + + +N7 + + +anotherfunc +0 of 1893.93kB (32.55%) + + + + + +N2->N7 + + + + + + + 1893.93kB + + + + + +N3 + + +normal +0 of 5817.23kB (100%) + + + + + +N9 + + +f_jitted +0 of 5817.23kB (100%) + + + + + +N3->N9 + + + + + + + 5817.23kB + + + + + +N4 + + +_xla_call_impl +0 of 5817.23kB (100%) + + + + + +N4->N1 + + + + + + + 5664.06kB + + + + + +N10 + + +memoized_fun +0 of 153.16kB (2.63%) + + + + + +N4->N10 + + + + + + + 153.16kB + + + + + +N5 + + +_xla_callable +142.89kB (2.46%) +of 153.16kB (2.63%) + + + + + +N5_0 + + + + + +kind:executable + + + + + +N5->N5_0 + + + + + + + 153.16kB + + + + + +N6->N3 + + + + + + + 3923.30kB + + + + + +N7->N3 + + + + + + + 1893.93kB + + + + + +N8 + + +call_bind +0 of 5817.23kB (100%) + + + + + +N8->N4 + + + + + + + 5817.23kB + + + + + +N9->N8 + + + + + + + 5817.23kB + + + + + +N10->N5 + + + + + + + 153.16kB + + + + + diff --git a/docs/_static/device_memory_profile_leak2.svg b/docs/_static/device_memory_profile_leak2.svg new file mode 100644 index 000000000..4817f5c4c --- /dev/null +++ b/docs/_static/device_memory_profile_leak2.svg @@ -0,0 +1,271 @@ + + + + + + +unnamed + + +cluster_L + + + + +Type: space + +Type: space +Showing nodes accounting for 1832.09kB, 46.05% of 3978.91kB total +Dropped 13 nodes (cum <= 19.89kB) +See https://git.io/JfYMW for how to read the graph + + + +N1 + + +_execute_compiled +1718.75kB (43.20%) + + + + + +N1_0 + + + + + +kind:buffer + + + + + +N1->N1_0 + + + + + + + 1718.75kB + + + + + +N2 + + +_xla_call_impl +0 of 1840.31kB (46.25%) + + + + + +N2->N1 + + + + + + + 1718.75kB + + + + + +N7 + + +memoized_fun +0 of 121.56kB (3.06%) + + + + + +N2->N7 + + + + + + + 121.56kB + + + + + +N3 + + +_xla_callable +113.34kB (2.85%) +of 121.56kB (3.06%) + + + + + +N3_0 + + + + + +kind:executable + + + + + +N3->N3_0 + + + + + + + 121.56kB + + + + + +N4 + + +anotherfunc +0 of 1840.31kB (46.25%) + + + + + +N9 + + +normal +0 of 1840.31kB (46.25%) + + + + + +N4->N9 + + + + + + + 1840.31kB + + + + + +N5 + + +call_bind +0 of 1840.31kB (46.25%) + + + + + +N5->N2 + + + + + + + 1840.31kB + + + + + +N6 + + +f_jitted +0 of 1840.31kB (46.25%) + + + + + +N6->N5 + + + + + + + 1840.31kB + + + + + +N7->N3 + + + + + + + 121.56kB + + + + + +N8 + + +<unknown> +0 of 1840.31kB (46.25%) + + + + + +N8->N4 + + + + + + + 1840.31kB + + + + + +N9->N6 + + + + + + + 1840.31kB + + + + + diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md new file mode 100644 index 000000000..4d34f44da --- /dev/null +++ b/docs/device_memory_profiling.md @@ -0,0 +1,142 @@ +# Device Memory Profiling + +The JAX Device Memory Profiler allows us to explore how and why JAX programs are +using GPU or TPU memory. For example, it can be used to: + +* Figure out which arrays and executables are in GPU memory at a given time, or +* Track down memory leaks. + +## Installation + +The JAX device memory profiler emits output that can be interpreted using +pprof (). Start by installing `pprof`, +by following its +[installation instructions](https://github.com/google/pprof#building-pprof). +At the time of writing, installing `pprof` requires first installing +[Go](https://golang.org/) and [Graphviz](http://www.graphviz.org/), and then +running + +```shell +go get -u github.com/google/pprof +``` + +which installs `pprof` as `$GOPATH/bin/pprof`, where `GOPATH` defaults to +`~/go`. + +```{note} +The version of `pprof` from is not the same as +the older tool of the same name distributed as part of the `gperftools` package. +The `gperftools` version of `pprof` will not work with JAX. +``` + +## Understanding how a JAX program is using GPU or TPU memory + +A common use of the device memory profiler is to figure out why a JAX program is +using a large amount of GPU or TPU memory, for example if trying to debug an +out-of-memory problem. + +To capture a device memory profile to disk, use +{func}`jax.profiler.save_device_memory_profile`. For example, consider the +following Python program: + +```python +import jax +import jax.numpy as jnp +import jax.profiler + +def func1(x): + return jnp.tile(x, 10) * 0.5 + +def func2(x): + y = func1(x) + return y, jnp.tile(x, 10) + 1 + +x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000)) +y, z = func2(x) + +z.block_until_ready() + +jax.profiler.save_device_memory_profile("memory.prof") +``` + +If we first run the program above and then execute + +```shell +pprof --web memory.prof +``` + +`pprof` opens a web browser containing the following visualization of the device +memory profile in callgraph format: + +![Device memory profiling example](_static/device_memory_profile.svg) + +The callgraph is a visualization of +the Python stack at the point the allocation of each live buffer was made. +For example, in this specific case, the visualization shows that +`func2` and its callees were responsible for allocating 76.30MB, of which +38.15MB was allocated inside the call from `func1` to `func2`. +For more information about how to interpret callgraph visualizations, see the +[pprof documentation](https://github.com/google/pprof/blob/master/doc/README.md#interpreting-the-callgraph). + +Functions compiled with {func}`jax.jit` are opaque to the device memory profiler. +That is, any memory allocated inside a `jit`-compiled function will be +attributed to the function as whole. + +In the example, the call to `block_until_ready()` is to ensure that `func2` +completes before the device memory profile is collected. See +{doc}`async_dispatch` for more details. + +## Debugging memory leaks + +We can also use the JAX device memory profiler to track down memory leaks by using +`pprof` to visualize the change in memory usage between two device memory profiles +taken at different times. For example consider the following program which +accumulates JAX arrays into a constantly-growing Python list. + +```python +import jax +import jax.numpy as jnp +import jax.profiler + +def afunction(): + return jax.random.normal(jax.random.PRNGKey(77), (1000000,)) + +z = afunction() + +def anotherfunc(): + arrays = [] + for i in range(1, 10): + x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000)) + arrays.append(x) + x.block_until_ready() + jax.profiler.save_device_memory_profile(f"memory{i}.prof") + +anotherfunc() +``` + +If we simply visualize the device memory profile at the end of execution +(`memory9.prof`), it may not be obvious that each iteration of the loop in +`anotherfunc` accumulates more device memory allocations: + +```shell +pprof --web memory9.prof +``` + +![Device memory profile at end of execution](_static/device_memory_profile_leak1.svg) + +The large but fixed allocation inside `afunction` dominates the profile but does +not grow over time. + +By using `pprof`'s +[`--diff_base` feature](https://github.com/google/pprof/blob/master/doc/README.md#comparing-profiles) to visualize the change in memory usage +across loop iterations, we can identify why the memory usage of the +program increases over time: + +```shell +pprof --web --diff_base memory1.prof memory9.prof +``` + +![Device memory profile at end of execution](_static/device_memory_profile_leak2.svg) + +The visualization shows that the memory growth can be attributed to the call to +`normal` inside `anotherfunc`. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index b7e9725ed..bfeb9ddeb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,7 @@ For an introduction to JAX, start at the concurrency gpu_memory_allocation profiling + device_memory_profiling pytrees rank_promotion_warning type_promotion diff --git a/docs/jax.profiler.rst b/docs/jax.profiler.rst index 3f97311f8..cbd6584bf 100644 --- a/docs/jax.profiler.rst +++ b/docs/jax.profiler.rst @@ -1,6 +1,33 @@ +.. currentmodule:: jax.profiler + jax.profiler module =================== .. automodule:: jax.profiler - :members: - :show-inheritance: \ No newline at end of file + +Tracing and time profiling +-------------------------- + +:doc:`profiling` describes how to make use of JAX's tracing and time profiling +features. + +.. autosummary:: + :toctree: _autosummary + + start_server + trace_function + TraceContext + + +Device memory profiling +----------------------- + +See :doc:`device_memory_profiling` for an introduction to JAX's device memory +profiling features. + +.. autosummary:: + :toctree: _autosummary + + device_memory_profile + save_device_memory_profile + \ No newline at end of file diff --git a/docs/profiling.md b/docs/profiling.md index 637307024..25b8724a4 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -113,7 +113,7 @@ sudo update-initramfs -u sudo reboot now ``` -See [Nvidia's documentation on this +See [NVIDIA's documentation on this error](https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti) for more information. @@ -133,7 +133,7 @@ ssh -L 6006:localhost:6006 ## Nsight -Nvidia's `Nsight` tools can be used to trace and profile JAX code on GPU. For +NVIDIA's `Nsight` tools can be used to trace and profile JAX code on GPU. For details, see the [`Nsight` documentation](https://developer.nvidia.com/tools-overview). diff --git a/jax/profiler.py b/jax/profiler.py index 96a20cf36..147122e10 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -13,8 +13,9 @@ # limitations under the License. from functools import wraps -from typing import Callable +from typing import Callable, Optional +from .lib import xla_bridge from .lib import xla_client @@ -86,3 +87,48 @@ def trace_function(func: Callable, name: str = None, **kwargs): return func(*args, **kwargs) return wrapper return wrapper + + +def device_memory_profile(backend: Optional[str] = None) -> bytes: + """Captures a JAX device memory profile as ``pprof``-format protocol buffer. + + A device memory profile is a snapshot of the state of memory, that describes the JAX + :class:`jax.DeviceArray` and executable objects present in memory and their + allocation sites. + + For more information how to use the device memory profiler, see + :doc:`/device_memory_profiling`. + + The profiling system works by instrumenting JAX on-device allocations, + capturing a Python stack trace for each allocation. The instrumentation is + always enabled; :func:`device_memory_profile` provides an API to capture it. + + The output of :func:`device_memory_profile` is a binary protocol buffer that + can be interpreted and visualized by the `pprof tool + `_. + + Args: + backend: optional; the name of the JAX backend for which the device memory + profile should be collected. + + Returns: + A byte string containing a binary `pprof`-format protocol buffer. + """ + return xla_client.heap_profile(xla_bridge.get_backend(backend)) + + +def save_device_memory_profile(filename, backend: Optional[str] = None): + """Collects a device memory profile and writes it to a file. + + :func:`save_device_memory_profile` is a convenience wrapper around :func:`device_memory_profile` + that saves its output to a ``filename``. See the + :func:`device_memory_profile` documentation for more information. + + Args: + filename: the filename to which the profile should be written. + backend: optional; the name of the JAX backend for which the device memory + profile should be collected. + """ + profile = device_memory_profile(backend) + with open(filename, "wb") as f: + f.write(profile) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 2fe9f0de9..818f89fef 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -18,6 +18,7 @@ import unittest from absl.testing import absltest import jax +import jax.numpy as jnp import jax.profiler from jax.config import config import jax.test_util @@ -61,6 +62,10 @@ class ProfilerTest(unittest.TestCase): return x + 2 self.assertEqual(h(7), 9) + def testDeviceMemoryProfile(self): + x = jnp.ones((20,)) + 7. + self.assertTrue(isinstance(jax.profiler.device_memory_profile(), bytes)) + del x if __name__ == "__main__": absltest.main()