Add programmatic profiling APIs, and rename some existing APIs.

This change provides aliases for the renamed APIs so existing code
won't break. We should remove these aliases after the next release.
This commit is contained in:
Skye Wanderman-Milne 2021-03-17 17:33:25 +00:00
parent b8812b2a5d
commit b68a08adf1
6 changed files with 258 additions and 28 deletions

View File

@ -10,9 +10,17 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.12 (unreleased) ## jax 0.2.12 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master). * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master).
* New features
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
* Breaking changes: * Breaking changes:
* The minimum jaxlib version is now 0.1.64. * The minimum jaxlib version is now 0.1.64.
* Some profiler APIs names have been changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
so please change your code.
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
## jaxlib 0.1.65 (unreleased) ## jaxlib 0.1.65 (unreleased)

View File

@ -15,8 +15,12 @@ features.
:toctree: _autosummary :toctree: _autosummary
start_server start_server
trace_function start_trace
TraceContext stop_trace
trace
annotate_function
TraceAnnotation
StepTraceAnnotation
Device memory profiling Device memory profiling
@ -30,4 +34,3 @@ profiling features.
device_memory_profile device_memory_profile
save_device_memory_profile save_device_memory_profile

View File

@ -18,7 +18,69 @@ Install specific nightly versions of TensorBoard, TensorBoard profiler, TensorFl
pip install --upgrade tb-nightly==2.5.0a20201203 tbp-nightly==2.4.0a20201203 tf-nightly==2.5.0.dev20201203 tensorboard-plugin-wit==1.7.0 pip install --upgrade tb-nightly==2.5.0a20201203 tbp-nightly==2.4.0a20201203 tf-nightly==2.5.0.dev20201203 tensorboard-plugin-wit==1.7.0
``` ```
### Usage ### Programmatic capture
You can instrument your code to capture a profiler trace via the
{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace`
methods. Call {func}`~jax.profiler.start_trace` with the directory to write
trace files to. This should be the same `--logdir` directory used to start
TensorBoard. Then, you can use TensorBoard to view the traces.
For example, to take a profiler trace:
```python
import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
```
Note the {func}`block_until_ready` call. We use this to make sure on-device
execution is captured by the trace. See {ref}`async-dispatch` for details on why
this is necessary.
You can also use the {func}`jax.profiler.trace` context manager as an
alternative to `start_trace` and `stop_trace`:
```python
import jax
with jax.profiler.trace():
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
```
To view the trace, first start TensorBoard if you haven't already:
```shell
$ tensorboard --logdir /tmp/tensorboard
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.5.0a20210304 at http://localhost:6006/ (Press CTRL+C to quit)
```
You should be able to load TensorBoard at <http://localhost:6006/> in this
example. You can specify a different port with the `--port` flag. See
{ref}`remote_profiling` below if running JAX on a remote server.
Then, either select "Profile" in the upper-right dropdown menu, or go directly
to <http://localhost:6006/#profile>. Available traces appear in the "Runs"
dropdown menu on the left. Select the run you're interested in, and then under
"Tools", select "trace_viewer". You should now see a timeline of the
execution. You can use the WASD keys to navigate the trace, and click or drag to
select events to see more details at the bottom. See [these TensorFlow
docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
for more details on using the trace viewer.
### Manual capture via TensorBoard
The following are instructions for capturing a manually-triggered N-second trace The following are instructions for capturing a manually-triggered N-second trace
from a running program. from a running program.
@ -43,7 +105,7 @@ from a running program.
This starts the profiler server that TensorBoard connects to. The profiler This starts the profiler server that TensorBoard connects to. The profiler
server must be running before you move on to the next step. It will remain server must be running before you move on to the next step. It will remain
alive and listening until the object returned by `start_server()` is alive and listening until the object returned by `start_server()` is
destroyed. destroyed.
If you'd like to profile a snippet of a long-running program (e.g. a long If you'd like to profile a snippet of a long-running program (e.g. a long
@ -76,10 +138,12 @@ from a running program.
docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
for more details on using the trace viewer.<br /><br /> for more details on using the trace viewer.<br /><br />
1. By default, the events in the trace viewer are mostly low-level internal JAX ### Adding custom trace events
functions. You can add your own events and functions by using
{func}`jax.profiler.TraceContext` and {func}`jax.profiler.trace_function` in By default, the events in the trace viewer are mostly low-level internal JAX
your code and capturing a new profile. functions. You can add your own events and functions by using
{class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in
your code.
### Troubleshooting ### Troubleshooting

View File

@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from contextlib import contextmanager
from functools import wraps from functools import wraps
import threading
from typing import Callable, Optional from typing import Callable, Optional
import warnings
from jax.lib import xla_bridge from jax.lib import xla_bridge
from jax.lib import xla_client from jax.lib import xla_client
@ -24,8 +27,8 @@ def start_server(port: int):
Using the "TensorFlow profiler" feature in `TensorBoard Using the "TensorFlow profiler" feature in `TensorBoard
<https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can <https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can
connect to the profiler server and sample execution traces that show CPU and connect to the profiler server and sample execution traces that show CPU,
GPU device activity. GPU, and/or TPU device activity.
Returns a profiler server object. The server remains alive and listening until Returns a profiler server object. The server remains alive and listening until
the server object is destroyed. the server object is destroyed.
@ -33,8 +36,80 @@ def start_server(port: int):
return xla_client.profiler.start_server(port) return xla_client.profiler.start_server(port)
class TraceContext(xla_client.profiler.TraceMe): class _ProfileState(object):
"""Context manager generates a trace event in the profiler. def __init__(self):
self.profile_session = None
self.log_dir = None
self.lock = threading.Lock()
_profile_state = _ProfileState()
def start_trace(log_dir):
"""Starts a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python
functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
and save the results to ``log_dir``.
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
doesn't need to be running when collecting the trace.
Only once trace may be collected a time. A RuntimeError will be raised if
``start_trace()`` is called while another trace is running.
Args:
log_dir: The directory to save the profiler trace to (usually the
TensorBoard log directory).
"""
with _profile_state.lock:
if _profile_state.profile_session is not None:
raise RuntimeError("Profile has already been started. "
"Only one profile may be run at a time.")
_profile_state.profile_session = xla_client.profiler.ProfilerSession()
_profile_state.log_dir = log_dir
def stop_trace():
"""Stops the currently-running profiler trace.
The trace will be saved to the ``log_dir`` passed to the corresponding
``start_trace()`` call. Raises a RuntimeError if a trace hasn't been started.
"""
with _profile_state.lock:
if _profile_state.profile_session is None:
raise RuntimeError("No profile started")
_profile_state.profile_session.stop_and_export(_profile_state.log_dir)
_profile_state.profile_session = None
_profile_state.log_dir = None
@contextmanager
def trace(log_dir):
"""Context manager to take a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python
functions and JAX on-device operations.
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
doesn't need to be running when collecting the trace.
Only once trace may be collected a time. A RuntimeError will be raised if a
trace is started while another trace is running.
Args:
log_dir: The directory to save the profiler trace to (usually the
TensorBoard log directory).
"""
start_trace(log_dir)
try:
yield
finally:
stop_trace()
class TraceAnnotation(xla_client.profiler.TraceMe):
"""Context manager that generates a trace event in the profiler.
The trace event spans the duration of the code enclosed by the context. The trace event spans the duration of the code enclosed by the context.
@ -42,16 +117,25 @@ class TraceContext(xla_client.profiler.TraceMe):
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> x = jnp.ones((1000, 1000)) >>> x = jnp.ones((1000, 1000))
>>> with jax.profiler.TraceContext("acontext"): >>> with jax.profiler.TraceAnnotation("my_label"):
... jnp.dot(x, x.T).block_until_ready() ... jnp.dot(x, x.T).block_until_ready()
This will cause an "acontext" event to show up on the trace timeline if the This will cause a "my_label" event to show up on the trace timeline if the
event occurs while the process is being traced by TensorBoard. event occurs while the process is being traced.
""" """
pass pass
class StepTraceContext(TraceContext): # TODO: remove this sometime after jax 0.1.11 is released
class TraceContext(TraceAnnotation):
def __init__(self, *args, **kwargs):
warnings.warn(
"TraceContext has been renamed to TraceAnnotation. This alias "
"will eventually be removed; please update your code.")
super().__init__(*args, **kwargs)
class StepTraceAnnotation(TraceAnnotation):
"""Context manager that generates a step trace event in the profiler. """Context manager that generates a step trace event in the profiler.
The step trace event spans the duration of the code enclosed by the context. The step trace event spans the duration of the code enclosed by the context.
@ -63,7 +147,7 @@ class StepTraceContext(TraceContext):
>>> import jax >>> import jax
>>> >>>
>>> while global_step < NUM_STEPS: >>> while global_step < NUM_STEPS:
... with jax.profiler.StepTraceContext("train", step_num=global_step): ... with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
... train_step() ... train_step()
... global_step += 1 ... global_step += 1
@ -79,14 +163,23 @@ class StepTraceContext(TraceContext):
super().__init__(name, _r=1, **kwargs) super().__init__(name, _r=1, **kwargs)
def trace_function(func: Callable, name: str = None, **kwargs): # TODO: remove this sometime after jax 0.1.11 is released
class StepTraceContext(StepTraceAnnotation):
def __init__(self, *args, **kwargs):
warnings.warn(
"StepTraceContext has been renamed to StepTraceAnnotation. This alias "
"will eventually be removed; please update your code.")
super().__init__(*args, **kwargs)
def annotate_function(func: Callable, name: str = None, **kwargs):
"""Decorator that generates a trace event for the execution of a function. """Decorator that generates a trace event for the execution of a function.
For example: For example:
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> >>>
>>> @jax.profiler.trace_function >>> @jax.profiler.annotate_function
>>> def f(x): >>> def f(x):
... return jnp.dot(x, x.T).block_until_ready() ... return jnp.dot(x, x.T).block_until_ready()
>>> >>>
@ -111,12 +204,21 @@ def trace_function(func: Callable, name: str = None, **kwargs):
name = name or func.__name__ name = name or func.__name__
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
with TraceContext(name, **kwargs): with TraceAnnotation(name, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
return wrapper return wrapper
# TODO: remove this sometime after jax 0.1.11 is released
def trace_function(*args, **kwargs):
warnings.warn(
"trace_function has been renamed to annotate_function. This alias "
"will eventually be removed; please update your code.")
return annotate_function(*args, **kwargs)
def device_memory_profile(backend: Optional[str] = None) -> bytes: def device_memory_profile(backend: Optional[str] = None) -> bytes:
"""Captures a JAX device memory profile as ``pprof``-format protocol buffer. """Captures a JAX device memory profile as ``pprof``-format protocol buffer.

View File

@ -14,10 +14,16 @@
# flake8: noqa: F401 # flake8: noqa: F401
from jax._src.profiler import ( from jax._src.profiler import (
StepTraceAnnotation,
StepTraceContext, StepTraceContext,
TraceAnnotation,
TraceContext, TraceContext,
device_memory_profile, device_memory_profile,
save_device_memory_profile, save_device_memory_profile,
start_server, start_server,
start_trace,
stop_trace,
trace,
annotate_function,
trace_function, trace_function,
) )

View File

@ -16,6 +16,7 @@ from functools import partial
import glob import glob
import os import os
import shutil import shutil
import tempfile
import threading import threading
import unittest import unittest
from absl.testing import absltest from absl.testing import absltest
@ -56,23 +57,69 @@ class ProfilerTest(unittest.TestCase):
jax.profiler.start_server(port=port) jax.profiler.start_server(port=port)
del port del port
def testTraceContext(self): def testProgrammaticProfiling(self):
with tempfile.TemporaryDirectory() as tmpdir:
jax.profiler.start_trace(tmpdir)
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
jnp.ones(jax.local_device_count()))
jax.profiler.stop_trace()
proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
recursive=True)
self.assertEqual(len(proto_path), 1)
with open(proto_path[0], "rb") as f:
proto = f.read()
# Sanity check that serialized proto contains host and device traces
# without deserializing.
self.assertIn(b"/host:CPU", proto)
if jtu.device_under_test() == "tpu":
self.assertIn(b"/device:TPU", proto)
def testProgrammaticProfilingErrors(self):
with self.assertRaisesRegex(RuntimeError, "No profile started"):
jax.profiler.stop_trace()
with tempfile.TemporaryDirectory() as tmpdir:
jax.profiler.start_trace(tmpdir)
with self.assertRaisesRegex(RuntimeError,
"Profile has already been started. Only one "
"profile may be run at a time."):
jax.profiler.start_trace(tmpdir)
def testProgrammaticProfilingContextManager(self):
with tempfile.TemporaryDirectory() as tmpdir:
with jax.profiler.trace(tmpdir):
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
jnp.ones(jax.local_device_count()))
proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
recursive=True)
self.assertEqual(len(proto_path), 1)
with open(proto_path[0], "rb") as f:
proto = f.read()
# Sanity check that serialized proto contains host and device traces
# without deserializing.
self.assertIn(b"/host:CPU", proto)
if jtu.device_under_test() == "tpu":
self.assertIn(b"/device:TPU", proto)
def testTraceAnnotation(self):
x = 3 x = 3
with jax.profiler.TraceContext("mycontext"): with jax.profiler.TraceAnnotation("mycontext"):
x = x + 2 x = x + 2
def testTraceFunction(self): def testTraceFunction(self):
@jax.profiler.trace_function @jax.profiler.annotate_function
def f(x): def f(x):
return x + 2 return x + 2
self.assertEqual(f(7), 9) self.assertEqual(f(7), 9)
@partial(jax.profiler.trace_function, name="aname") @partial(jax.profiler.annotate_function, name="aname")
def g(x): def g(x):
return x + 2 return x + 2
self.assertEqual(g(7), 9) self.assertEqual(g(7), 9)
@partial(jax.profiler.trace_function, name="aname", akwarg="hello") @partial(jax.profiler.annotate_function, name="aname", akwarg="hello")
def h(x): def h(x):
return x + 2 return x + 2
self.assertEqual(h(7), 9) self.assertEqual(h(7), 9)
@ -96,7 +143,7 @@ class ProfilerTest(unittest.TestCase):
worker_start.set() worker_start.set()
x = jnp.ones((1000, 1000)) x = jnp.ones((1000, 1000))
while True: while True:
with jax.profiler.TraceContext("atracecontext"): with jax.profiler.TraceAnnotation("atraceannotation"):
jnp.dot(x, x.T).block_until_ready() jnp.dot(x, x.T).block_until_ready()
if self.profile_done: if self.profile_done:
break break