diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2dc46bf81..7c383aa73 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,9 +10,17 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.12 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master).
-
+* New features
+ * New profiling APIs: {func}`jax.profiler.start_trace`,
+ {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
* Breaking changes:
* The minimum jaxlib version is now 0.1.64.
+ * Some profiler API names have been changed. There are still aliases, so this
+ should not break existing code, but the aliases will eventually be removed
+ so please change your code.
+ * `TraceContext` --> {class}`~jax.profiler.TraceAnnotation`
+ * `StepTraceContext` --> {class}`~jax.profiler.StepTraceAnnotation`
+ * `trace_function` --> {func}`~jax.profiler.annotate_function`
## jaxlib 0.1.65 (unreleased)
diff --git a/docs/jax.profiler.rst b/docs/jax.profiler.rst
index cbd6584bf..7e99fc388 100644
--- a/docs/jax.profiler.rst
+++ b/docs/jax.profiler.rst
@@ -15,8 +15,12 @@ features.
:toctree: _autosummary
start_server
- trace_function
- TraceContext
+ start_trace
+ stop_trace
+ trace
+ annotate_function
+ TraceAnnotation
+ StepTraceAnnotation
Device memory profiling
@@ -30,4 +34,3 @@ profiling features.
device_memory_profile
save_device_memory_profile
-
\ No newline at end of file
diff --git a/docs/profiling.md b/docs/profiling.md
index d7244bd91..1fe324b45 100644
--- a/docs/profiling.md
+++ b/docs/profiling.md
@@ -18,7 +18,69 @@ Install specific nightly versions of TensorBoard, TensorBoard profiler, TensorFl
pip install --upgrade tb-nightly==2.5.0a20201203 tbp-nightly==2.4.0a20201203 tf-nightly==2.5.0.dev20201203 tensorboard-plugin-wit==1.7.0
```
-### Usage
+### Programmatic capture
+
+You can instrument your code to capture a profiler trace via the
+{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace`
+methods. Call {func}`~jax.profiler.start_trace` with the directory to write
+trace files to. This should be the same `--logdir` directory used to start
+TensorBoard. Then, you can use TensorBoard to view the traces.
+
+For example, to take a profiler trace:
+
+```python
+import jax
+
+jax.profiler.start_trace("/tmp/tensorboard")
+
+# Run the operations to be profiled
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (5000, 5000))
+y = x @ x
+y.block_until_ready()
+
+jax.profiler.stop_trace()
+```
+
+Note the {func}`block_until_ready` call. We use this to make sure on-device
+execution is captured by the trace. See {ref}`async-dispatch` for details on why
+this is necessary.
+
+You can also use the {func}`jax.profiler.trace` context manager as an
+alternative to `start_trace` and `stop_trace`:
+
+```python
+import jax
+
+with jax.profiler.trace("/tmp/tensorboard"):
+ key = jax.random.PRNGKey(0)
+ x = jax.random.normal(key, (5000, 5000))
+ y = x @ x
+ y.block_until_ready()
+```
+
+To view the trace, first start TensorBoard if you haven't already:
+
+```shell
+$ tensorboard --logdir /tmp/tensorboard
+Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+TensorBoard 2.5.0a20210304 at http://localhost:6006/ (Press CTRL+C to quit)
+```
+
+You should be able to load TensorBoard at <http://localhost:6006/> in this
+example. You can specify a different port with the `--port` flag. See
+{ref}`remote_profiling` below if running JAX on a remote server.
+
+Then, either select "Profile" in the upper-right dropdown menu, or go directly
+to <http://localhost:6006/#profile>. Available traces appear in the "Runs"
+dropdown menu on the left. Select the run you're interested in, and then under
+"Tools", select "trace_viewer". You should now see a timeline of the
+execution. You can use the WASD keys to navigate the trace, and click or drag to
+select events to see more details at the bottom. See [these TensorFlow
+docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
+for more details on using the trace viewer.
+
+### Manual capture via TensorBoard
The following are instructions for capturing a manually-triggered N-second trace
from a running program.
@@ -43,7 +105,7 @@ from a running program.
This starts the profiler server that TensorBoard connects to. The profiler
server must be running before you move on to the next step. It will remain
- alive and listening until the object returned by `start_server()` is
+ alive and listening until the object returned by `start_server()` is
destroyed.
If you'd like to profile a snippet of a long-running program (e.g. a long
@@ -76,10 +138,12 @@ from a running program.
docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
for more details on using the trace viewer.
-1. By default, the events in the trace viewer are mostly low-level internal JAX
- functions. You can add your own events and functions by using
- {func}`jax.profiler.TraceContext` and {func}`jax.profiler.trace_function` in
- your code and capturing a new profile.
+### Adding custom trace events
+
+By default, the events in the trace viewer are mostly low-level internal JAX
+functions. You can add your own events and functions by using
+{class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in
+your code.
### Troubleshooting
diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py
index b8a7fa265..192e4487a 100644
--- a/jax/_src/profiler.py
+++ b/jax/_src/profiler.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from contextlib import contextmanager
from functools import wraps
+import threading
from typing import Callable, Optional
+import warnings
from jax.lib import xla_bridge
from jax.lib import xla_client
@@ -24,8 +27,8 @@ def start_server(port: int):
Using the "TensorFlow profiler" feature in `TensorBoard
<https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can
- connect to the profiler server and sample execution traces that show CPU and
- GPU device activity.
+ connect to the profiler server and sample execution traces that show CPU,
+ GPU, and/or TPU device activity.
Returns a profiler server object. The server remains alive and listening until
the server object is destroyed.
@@ -33,8 +36,80 @@ def start_server(port: int):
return xla_client.profiler.start_server(port)
-class TraceContext(xla_client.profiler.TraceMe):
- """Context manager generates a trace event in the profiler.
+class _ProfileState(object):
+ def __init__(self):
+ self.profile_session = None
+ self.log_dir = None
+ self.lock = threading.Lock()
+
+_profile_state = _ProfileState()
+
+
+def start_trace(log_dir):
+ """Starts a profiler trace.
+
+ The trace will capture CPU, GPU, and/or TPU activity, including Python
+ functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
+ and save the results to ``log_dir``.
+
+ The resulting trace can be viewed with TensorBoard. Note that TensorBoard
+ doesn't need to be running when collecting the trace.
+
+ Only one trace may be collected at a time. A RuntimeError will be raised if
+ ``start_trace()`` is called while another trace is running.
+
+ Args:
+ log_dir: The directory to save the profiler trace to (usually the
+ TensorBoard log directory).
+ """
+ with _profile_state.lock:
+ if _profile_state.profile_session is not None:
+ raise RuntimeError("Profile has already been started. "
+ "Only one profile may be run at a time.")
+ _profile_state.profile_session = xla_client.profiler.ProfilerSession()
+ _profile_state.log_dir = log_dir
+
+
+def stop_trace():
+ """Stops the currently-running profiler trace.
+
+ The trace will be saved to the ``log_dir`` passed to the corresponding
+ ``start_trace()`` call. Raises a RuntimeError if a trace hasn't been started.
+ """
+ with _profile_state.lock:
+ if _profile_state.profile_session is None:
+ raise RuntimeError("No profile started")
+ _profile_state.profile_session.stop_and_export(_profile_state.log_dir)
+ _profile_state.profile_session = None
+ _profile_state.log_dir = None
+
+
+@contextmanager
+def trace(log_dir):
+ """Context manager to take a profiler trace.
+
+ The trace will capture CPU, GPU, and/or TPU activity, including Python
+ functions and JAX on-device operations.
+
+ The resulting trace can be viewed with TensorBoard. Note that TensorBoard
+ doesn't need to be running when collecting the trace.
+
+ Only one trace may be collected at a time. A RuntimeError will be raised if a
+ trace is started while another trace is running.
+
+ Args:
+ log_dir: The directory to save the profiler trace to (usually the
+ TensorBoard log directory).
+ """
+ start_trace(log_dir)
+ try:
+ yield
+ finally:
+ stop_trace()
+
+
+class TraceAnnotation(xla_client.profiler.TraceMe):
+ """Context manager that generates a trace event in the profiler.
The trace event spans the duration of the code enclosed by the context.
@@ -42,16 +117,25 @@ class TraceContext(xla_client.profiler.TraceMe):
>>> import jax, jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
- >>> with jax.profiler.TraceContext("acontext"):
+ >>> with jax.profiler.TraceAnnotation("my_label"):
... jnp.dot(x, x.T).block_until_ready()
- This will cause an "acontext" event to show up on the trace timeline if the
- event occurs while the process is being traced by TensorBoard.
+ This will cause a "my_label" event to show up on the trace timeline if the
+ event occurs while the process is being traced.
"""
pass
-class StepTraceContext(TraceContext):
+# TODO: remove this sometime after jax 0.2.12 is released
+class TraceContext(TraceAnnotation):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ "TraceContext has been renamed to TraceAnnotation. This alias "
+ "will eventually be removed; please update your code.")
+ super().__init__(*args, **kwargs)
+
+
+class StepTraceAnnotation(TraceAnnotation):
"""Context manager that generates a step trace event in the profiler.
The step trace event spans the duration of the code enclosed by the context.
@@ -63,7 +147,7 @@ class StepTraceContext(TraceContext):
>>> import jax
>>>
>>> while global_step < NUM_STEPS:
- ... with jax.profiler.StepTraceContext("train", step_num=global_step):
+ ... with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
... train_step()
... global_step += 1
@@ -79,14 +163,23 @@ class StepTraceContext(TraceContext):
super().__init__(name, _r=1, **kwargs)
-def trace_function(func: Callable, name: str = None, **kwargs):
+# TODO: remove this sometime after jax 0.2.12 is released
+class StepTraceContext(StepTraceAnnotation):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ "StepTraceContext has been renamed to StepTraceAnnotation. This alias "
+ "will eventually be removed; please update your code.")
+ super().__init__(*args, **kwargs)
+
+
+def annotate_function(func: Callable, name: str = None, **kwargs):
"""Decorator that generates a trace event for the execution of a function.
For example:
>>> import jax, jax.numpy as jnp
>>>
- >>> @jax.profiler.trace_function
+ >>> @jax.profiler.annotate_function
>>> def f(x):
... return jnp.dot(x, x.T).block_until_ready()
>>>
@@ -111,12 +204,21 @@ def trace_function(func: Callable, name: str = None, **kwargs):
name = name or func.__name__
@wraps(func)
def wrapper(*args, **kwargs):
- with TraceContext(name, **kwargs):
+ with TraceAnnotation(name, **kwargs):
return func(*args, **kwargs)
return wrapper
return wrapper
+# TODO: remove this sometime after jax 0.2.12 is released
+def trace_function(*args, **kwargs):
+ warnings.warn(
+ "trace_function has been renamed to annotate_function. This alias "
+ "will eventually be removed; please update your code.")
+ return annotate_function(*args, **kwargs)
+
+
+
def device_memory_profile(backend: Optional[str] = None) -> bytes:
"""Captures a JAX device memory profile as ``pprof``-format protocol buffer.
diff --git a/jax/profiler.py b/jax/profiler.py
index a66eb9f8e..e8e5db9e6 100644
--- a/jax/profiler.py
+++ b/jax/profiler.py
@@ -14,10 +14,16 @@
# flake8: noqa: F401
from jax._src.profiler import (
+ StepTraceAnnotation,
StepTraceContext,
+ TraceAnnotation,
TraceContext,
device_memory_profile,
save_device_memory_profile,
start_server,
+ start_trace,
+ stop_trace,
+ trace,
+ annotate_function,
trace_function,
)
diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index c7aa297b1..21ceb632c 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -16,6 +16,7 @@ from functools import partial
import glob
import os
import shutil
+import tempfile
import threading
import unittest
from absl.testing import absltest
@@ -56,23 +57,69 @@ class ProfilerTest(unittest.TestCase):
jax.profiler.start_server(port=port)
del port
- def testTraceContext(self):
+ def testProgrammaticProfiling(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ jax.profiler.start_trace(tmpdir)
+ jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
+ jnp.ones(jax.local_device_count()))
+ jax.profiler.stop_trace()
+
+ proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
+ recursive=True)
+ self.assertEqual(len(proto_path), 1)
+ with open(proto_path[0], "rb") as f:
+ proto = f.read()
+ # Sanity check that serialized proto contains host and device traces
+ # without deserializing.
+ self.assertIn(b"/host:CPU", proto)
+ if jtu.device_under_test() == "tpu":
+ self.assertIn(b"/device:TPU", proto)
+
+ def testProgrammaticProfilingErrors(self):
+ with self.assertRaisesRegex(RuntimeError, "No profile started"):
+ jax.profiler.stop_trace()
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ jax.profiler.start_trace(tmpdir)
+ with self.assertRaisesRegex(RuntimeError,
+ "Profile has already been started. Only one "
+ "profile may be run at a time."):
+ jax.profiler.start_trace(tmpdir)
+
+ def testProgrammaticProfilingContextManager(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ with jax.profiler.trace(tmpdir):
+ jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
+ jnp.ones(jax.local_device_count()))
+
+ proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
+ recursive=True)
+ self.assertEqual(len(proto_path), 1)
+ with open(proto_path[0], "rb") as f:
+ proto = f.read()
+ # Sanity check that serialized proto contains host and device traces
+ # without deserializing.
+ self.assertIn(b"/host:CPU", proto)
+ if jtu.device_under_test() == "tpu":
+ self.assertIn(b"/device:TPU", proto)
+
+ def testTraceAnnotation(self):
x = 3
- with jax.profiler.TraceContext("mycontext"):
+ with jax.profiler.TraceAnnotation("mycontext"):
x = x + 2
def testTraceFunction(self):
- @jax.profiler.trace_function
+ @jax.profiler.annotate_function
def f(x):
return x + 2
self.assertEqual(f(7), 9)
- @partial(jax.profiler.trace_function, name="aname")
+ @partial(jax.profiler.annotate_function, name="aname")
def g(x):
return x + 2
self.assertEqual(g(7), 9)
- @partial(jax.profiler.trace_function, name="aname", akwarg="hello")
+ @partial(jax.profiler.annotate_function, name="aname", akwarg="hello")
def h(x):
return x + 2
self.assertEqual(h(7), 9)
@@ -96,7 +143,7 @@ class ProfilerTest(unittest.TestCase):
worker_start.set()
x = jnp.ones((1000, 1000))
while True:
- with jax.profiler.TraceContext("atracecontext"):
+ with jax.profiler.TraceAnnotation("atraceannotation"):
jnp.dot(x, x.T).block_until_ready()
if self.profile_done:
break