diff --git a/docs/_static/device_memory_profile.svg b/docs/_static/device_memory_profile.svg
new file mode 100644
index 000000000..fb0b5e498
--- /dev/null
+++ b/docs/_static/device_memory_profile.svg
@@ -0,0 +1,435 @@
+
+
+
+
+
diff --git a/docs/_static/device_memory_profile_leak1.svg b/docs/_static/device_memory_profile_leak1.svg
new file mode 100644
index 000000000..53fc2848a
--- /dev/null
+++ b/docs/_static/device_memory_profile_leak1.svg
@@ -0,0 +1,307 @@
+
+
+
+
+
diff --git a/docs/_static/device_memory_profile_leak2.svg b/docs/_static/device_memory_profile_leak2.svg
new file mode 100644
index 000000000..4817f5c4c
--- /dev/null
+++ b/docs/_static/device_memory_profile_leak2.svg
@@ -0,0 +1,271 @@
+
+
+
+
+
diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md
new file mode 100644
index 000000000..4d34f44da
--- /dev/null
+++ b/docs/device_memory_profiling.md
@@ -0,0 +1,142 @@
+# Device Memory Profiling
+
+The JAX Device Memory Profiler allows us to explore how and why JAX programs are
+using GPU or TPU memory. For example, it can be used to:
+
+* Figure out which arrays and executables are in GPU memory at a given time, or
+* Track down memory leaks.
+
+## Installation
+
+The JAX device memory profiler emits output that can be interpreted using
+pprof (<https://github.com/google/pprof>). Start by installing `pprof` by
+following its
+[installation instructions](https://github.com/google/pprof#building-pprof).
+At the time of writing, installing `pprof` requires first installing
+[Go](https://golang.org/) and [Graphviz](http://www.graphviz.org/), and then
+running
+
+```shell
+go get -u github.com/google/pprof
+```
+
+which installs `pprof` as `$GOPATH/bin/pprof`, where `GOPATH` defaults to
+`~/go`.
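+
+If `pprof` is not already on your `PATH` after this step, one way to add it,
+assuming a standard Go setup, is:
+
+```shell
+# Add Go's binary directory (where `go get` installs pprof) to the PATH.
+export PATH="$PATH:$(go env GOPATH)/bin"
+# Running pprof with no arguments should now print its usage message.
+pprof
+```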
+
+```{note}
+The version of `pprof` from <https://github.com/google/pprof> is not the same as
+the older tool of the same name distributed as part of the `gperftools` package.
+The `gperftools` version of `pprof` will not work with JAX.
+```
+
+## Understanding how a JAX program is using GPU or TPU memory
+
+A common use of the device memory profiler is to figure out why a JAX program is
+using a large amount of GPU or TPU memory, for example when debugging an
+out-of-memory problem.
+
+To capture a device memory profile to disk, use
+{func}`jax.profiler.save_device_memory_profile`. For example, consider the
+following Python program:
+
+```python
+import jax
+import jax.numpy as jnp
+import jax.profiler
+
+def func1(x):
+  return jnp.tile(x, 10) * 0.5
+
+def func2(x):
+  y = func1(x)
+  return y, jnp.tile(x, 10) + 1
+
+x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
+y, z = func2(x)
+
+z.block_until_ready()
+
+jax.profiler.save_device_memory_profile("memory.prof")
+```
+
+If we first run the program above and then execute
+
+```shell
+pprof --web memory.prof
+```
+
+`pprof` opens a web browser containing the following visualization of the device
+memory profile in callgraph format:
+
+![Device memory profiling example](_static/device_memory_profile.svg)
+
+The callgraph is a visualization of
+the Python stack at the point the allocation of each live buffer was made.
+For example, in this specific case, the visualization shows that
+`func2` and its callees were responsible for allocating 76.30MB, of which
+38.15MB was allocated inside the call from `func2` to `func1`.
+For more information about how to interpret callgraph visualizations, see the
+[pprof documentation](https://github.com/google/pprof/blob/master/doc/README.md#interpreting-the-callgraph).
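+
+If a web browser is inconvenient (for example on a remote machine), `pprof` can
+also print its reports as plain text. As a quick alternative to the callgraph
+view above, the following lists the functions holding the largest live
+allocations in `memory.prof`:
+
+```shell
+pprof --top memory.prof
+```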
+
+Functions compiled with {func}`jax.jit` are opaque to the device memory profiler.
+That is, any memory allocated inside a `jit`-compiled function will be
+attributed to the function as a whole.
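+
+As a small illustrative sketch (the function `tiled` below is not part of the
+example above), a device memory profile of the following program attributes the
+memory it allocates to `tiled` as a whole rather than to the operations inside
+it:
+
+```python
+import jax
+import jax.numpy as jnp
+
+@jax.jit
+def tiled(x):
+  # Allocations made while this compiled function executes show up under
+  # `tiled` in the device memory profile, with no further breakdown.
+  return jnp.tile(x, 10) * 0.5
+
+y = tiled(jnp.ones((1000,)))
+y.block_until_ready()
+```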
+
+In the example, the call to `block_until_ready()` ensures that `func2` has
+completed before the device memory profile is collected; since JAX dispatches
+computations asynchronously, the profile might otherwise be captured too early.
+See {doc}`async_dispatch` for more details.
+
+## Debugging memory leaks
+
+We can also use the JAX device memory profiler to track down memory leaks by using
+`pprof` to visualize the change in memory usage between two device memory profiles
+taken at different times. For example, consider the following program, which
+accumulates JAX arrays into a constantly-growing Python list.
+
+```python
+import jax
+import jax.numpy as jnp
+import jax.profiler
+
+def afunction():
+  return jax.random.normal(jax.random.PRNGKey(77), (1000000,))
+
+z = afunction()
+
+def anotherfunc():
+  arrays = []
+  for i in range(1, 10):
+    x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
+    arrays.append(x)
+    x.block_until_ready()
+    jax.profiler.save_device_memory_profile(f"memory{i}.prof")
+
+anotherfunc()
+```
+
+If we simply visualize the device memory profile at the end of execution
+(`memory9.prof`), it may not be obvious that each iteration of the loop in
+`anotherfunc` accumulates more device memory allocations:
+
+```shell
+pprof --web memory9.prof
+```
+
+![Device memory profile of the leaking program](_static/device_memory_profile_leak1.svg)
+
+The large but fixed allocation inside `afunction` dominates the profile but does
+not grow over time.
+
+By using `pprof`'s
+[`--diff_base` feature](https://github.com/google/pprof/blob/master/doc/README.md#comparing-profiles) to visualize the change in memory usage
+across loop iterations, we can identify why the memory usage of the
+program increases over time:
+
+```shell
+pprof --web --diff_base memory1.prof memory9.prof
+```
+
+![Device memory profile difference showing the leak](_static/device_memory_profile_leak2.svg)
+
+The visualization shows that the memory growth can be attributed to the call to
+`normal` inside `anotherfunc`.
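+
+The same comparison can also be printed as a plain-text table using `pprof`'s
+`--top` report, which can be convenient when no browser is available:
+
+```shell
+pprof --top --diff_base memory1.prof memory9.prof
+```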
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index b7e9725ed..bfeb9ddeb 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -38,6 +38,7 @@ For an introduction to JAX, start at the
concurrency
gpu_memory_allocation
profiling
+ device_memory_profiling
pytrees
rank_promotion_warning
type_promotion
diff --git a/docs/jax.profiler.rst b/docs/jax.profiler.rst
index 3f97311f8..cbd6584bf 100644
--- a/docs/jax.profiler.rst
+++ b/docs/jax.profiler.rst
@@ -1,6 +1,33 @@
+.. currentmodule:: jax.profiler
+
jax.profiler module
===================
.. automodule:: jax.profiler
- :members:
- :show-inheritance:
\ No newline at end of file
+
+Tracing and time profiling
+--------------------------
+
+:doc:`profiling` describes how to make use of JAX's tracing and time profiling
+features.
+
+.. autosummary::
+  :toctree: _autosummary
+
+  start_server
+  trace_function
+  TraceContext
+
+
+Device memory profiling
+-----------------------
+
+See :doc:`device_memory_profiling` for an introduction to JAX's device memory
+profiling features.
+
+.. autosummary::
+  :toctree: _autosummary
+
+  device_memory_profile
+  save_device_memory_profile
+
\ No newline at end of file
diff --git a/docs/profiling.md b/docs/profiling.md
index 637307024..25b8724a4 100644
--- a/docs/profiling.md
+++ b/docs/profiling.md
@@ -113,7 +113,7 @@ sudo update-initramfs -u
sudo reboot now
```
-See [Nvidia's documentation on this
+See [NVIDIA's documentation on this
error](https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti)
for more information.
@@ -133,7 +133,7 @@ ssh -L 6006:localhost:6006
## Nsight
-Nvidia's `Nsight` tools can be used to trace and profile JAX code on GPU. For
+NVIDIA's `Nsight` tools can be used to trace and profile JAX code on GPU. For
details, see the [`Nsight`
documentation](https://developer.nvidia.com/tools-overview).
diff --git a/jax/profiler.py b/jax/profiler.py
index 96a20cf36..147122e10 100644
--- a/jax/profiler.py
+++ b/jax/profiler.py
@@ -13,8 +13,9 @@
# limitations under the License.
from functools import wraps
-from typing import Callable
+from typing import Callable, Optional
+from .lib import xla_bridge
from .lib import xla_client
@@ -86,3 +87,48 @@ def trace_function(func: Callable, name: str = None, **kwargs):
return func(*args, **kwargs)
return wrapper
return wrapper
+
+
+def device_memory_profile(backend: Optional[str] = None) -> bytes:
+ """Captures a JAX device memory profile as ``pprof``-format protocol buffer.
+
+ A device memory profile is a snapshot of the state of memory, that describes the JAX
+ :class:`jax.DeviceArray` and executable objects present in memory and their
+ allocation sites.
+
+ For more information how to use the device memory profiler, see
+ :doc:`/device_memory_profiling`.
+
+ The profiling system works by instrumenting JAX on-device allocations,
+ capturing a Python stack trace for each allocation. The instrumentation is
+ always enabled; :func:`device_memory_profile` provides an API to capture it.
+
+ The output of :func:`device_memory_profile` is a binary protocol buffer that
+ can be interpreted and visualized by the `pprof tool
+ `_.
+
+ Args:
+ backend: optional; the name of the JAX backend for which the device memory
+ profile should be collected.
+
+ Returns:
+ A byte string containing a binary `pprof`-format protocol buffer.
+ """
+ return xla_client.heap_profile(xla_bridge.get_backend(backend))
+
+
+def save_device_memory_profile(filename, backend: Optional[str] = None):
+ """Collects a device memory profile and writes it to a file.
+
+ :func:`save_device_memory_profile` is a convenience wrapper around :func:`device_memory_profile`
+ that saves its output to a ``filename``. See the
+ :func:`device_memory_profile` documentation for more information.
+
+ Args:
+ filename: the filename to which the profile should be written.
+ backend: optional; the name of the JAX backend for which the device memory
+ profile should be collected.
+ """
+ profile = device_memory_profile(backend)
+ with open(filename, "wb") as f:
+ f.write(profile)
diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index 2fe9f0de9..818f89fef 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -18,6 +18,7 @@ import unittest
from absl.testing import absltest
import jax
+import jax.numpy as jnp
import jax.profiler
from jax.config import config
import jax.test_util
@@ -61,6 +62,10 @@ class ProfilerTest(unittest.TestCase):
return x + 2
self.assertEqual(h(7), 9)
+  def testDeviceMemoryProfile(self):
+    x = jnp.ones((20,)) + 7.
+    self.assertIsInstance(jax.profiler.device_memory_profile(), bytes)
+    del x
if __name__ == "__main__":
absltest.main()