Add programmatic profiling APIs, and rename some existing APIs.

This change provides aliases for the renamed APIs so existing code
won't break. We should remove these aliases after the next release.
Skye Wanderman-Milne 2021-03-17 17:33:25 +00:00
parent b8812b2a5d
commit b68a08adf1
6 changed files with 258 additions and 28 deletions


@ -10,9 +10,17 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.12 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master).
* New features
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
* Breaking changes:
* The minimum jaxlib version is now 0.1.64.
* Some profiler API names have been changed. Aliases for the old names remain,
so existing code should not break, but the aliases will eventually be removed,
so please update your code.
* `TraceContext` --> {class}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {class}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
## jaxlib 0.1.65 (unreleased)


@ -15,8 +15,12 @@ features.
:toctree: _autosummary
start_server
trace_function
TraceContext
start_trace
stop_trace
trace
annotate_function
TraceAnnotation
StepTraceAnnotation
Device memory profiling
@ -30,4 +34,3 @@ profiling features.
device_memory_profile
save_device_memory_profile


@ -18,7 +18,69 @@ Install specific nightly versions of TensorBoard, TensorBoard profiler, TensorFl
pip install --upgrade tb-nightly==2.5.0a20201203 tbp-nightly==2.4.0a20201203 tf-nightly==2.5.0.dev20201203 tensorboard-plugin-wit==1.7.0
```
### Usage
### Programmatic capture
You can instrument your code to capture a profiler trace with the
{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace`
functions. Call {func}`~jax.profiler.start_trace` with the directory to write
trace files to. This should be the same `--logdir` directory used to start
TensorBoard. Then, you can use TensorBoard to view the traces.
For example, to take a profiler trace:
```python
import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
```
Note the {func}`block_until_ready` call. We use this to make sure on-device
execution is captured by the trace. See {ref}`async-dispatch` for details on why
this is necessary.
You can also use the {func}`jax.profiler.trace` context manager as an
alternative to `start_trace` and `stop_trace`:
```python
import jax
with jax.profiler.trace("/tmp/tensorboard"):
  # Run the operations to be profiled
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()
```
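The context manager behaves like a `start_trace` / `stop_trace` pair wrapped in
`try`/`finally`, so the trace is still stopped and written if the profiled code
raises. The following is a minimal sketch of that pattern rather than a
prescribed usage. It assumes a temporary directory as a stand-in for the
TensorBoard log directory, and the `trace_files` lookup is only an illustration
of where the `*.xplane.pb` output ends up.

```python
import glob
import os
import tempfile

import jax
import jax.numpy as jnp

# Assumption: a throwaway directory instead of the real TensorBoard --logdir.
log_dir = tempfile.mkdtemp()

jax.profiler.start_trace(log_dir)
try:
  # Run the operations to be profiled.
  x = jnp.ones((1000, 1000))
  (x @ x).block_until_ready()
finally:
  # Always stop the trace; only one trace may be active at a time.
  jax.profiler.stop_trace()

# The profiler writes serialized trace data somewhere below log_dir.
trace_files = glob.glob(os.path.join(log_dir, "**", "*.xplane.pb"),
                        recursive=True)
```

Without the `finally` clause, an exception in the profiled code would leave the
trace running, and the next `start_trace` call would raise a `RuntimeError`.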
To view the trace, first start TensorBoard if you haven't already:
```shell
$ tensorboard --logdir /tmp/tensorboard
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.5.0a20210304 at http://localhost:6006/ (Press CTRL+C to quit)
```
You should be able to load TensorBoard at <http://localhost:6006/> in this
example. You can specify a different port with the `--port` flag. See
{ref}`remote_profiling` below if running JAX on a remote server.
Then, either select "Profile" in the upper-right dropdown menu, or go directly
to <http://localhost:6006/#profile>. Available traces appear in the "Runs"
dropdown menu on the left. Select the run you're interested in, and then under
"Tools", select "trace_viewer". You should now see a timeline of the
execution. You can use the WASD keys to navigate the trace, and click or drag to
select events to see more details at the bottom. See [these TensorFlow
docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
for more details on using the trace viewer.
### Manual capture via TensorBoard
The following are instructions for capturing a manually-triggered N-second trace
from a running program.
@ -43,7 +105,7 @@ from a running program.
This starts the profiler server that TensorBoard connects to. The profiler
server must be running before you move on to the next step. It will remain
alive and listening until the object returned by `start_server()` is
alive and listening until the object returned by `start_server()` is
destroyed.
If you'd like to profile a snippet of a long-running program (e.g. a long
@ -76,10 +138,12 @@ from a running program.
docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
for more details on using the trace viewer.<br /><br />
1. By default, the events in the trace viewer are mostly low-level internal JAX
functions. You can add your own events and functions by using
{func}`jax.profiler.TraceContext` and {func}`jax.profiler.trace_function` in
your code and capturing a new profile.
### Adding custom trace events
By default, the events in the trace viewer are mostly low-level internal JAX
functions. You can add your own events and functions by using
{class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in
your code.
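As a rough sketch of how the two annotations can be combined, the example below
labels one region with a context manager and one function with the decorator.
The names `normalize`, `step`, and `"matmul_block"` are hypothetical and chosen
only for illustration. The events appear on the timeline only if a trace is
being collected while this code runs.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: the whole function shows up as one "normalize" event.
@jax.profiler.annotate_function
def normalize(x):
  return (x / jnp.linalg.norm(x)).block_until_ready()


def step(x):
  # Hypothetical label: only the enclosed region shows up as "matmul_block".
  with jax.profiler.TraceAnnotation("matmul_block"):
    y = jnp.dot(x, x.T).block_until_ready()
  return normalize(y)


result = step(jnp.ones((100, 100)))
```

Both annotations call `block_until_ready()` inside the annotated region so that
the event covers the on-device execution and not just the dispatch.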
### Troubleshooting


@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from functools import wraps
import threading
from typing import Callable, Optional
import warnings
from jax.lib import xla_bridge
from jax.lib import xla_client
@ -24,8 +27,8 @@ def start_server(port: int):
Using the "TensorFlow profiler" feature in `TensorBoard
<https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can
connect to the profiler server and sample execution traces that show CPU and
GPU device activity.
connect to the profiler server and sample execution traces that show CPU,
GPU, and/or TPU device activity.
Returns a profiler server object. The server remains alive and listening until
the server object is destroyed.
@ -33,8 +36,80 @@ def start_server(port: int):
return xla_client.profiler.start_server(port)
class TraceContext(xla_client.profiler.TraceMe):
"""Context manager generates a trace event in the profiler.
class _ProfileState(object):
  def __init__(self):
    self.profile_session = None
    self.log_dir = None
    self.lock = threading.Lock()

_profile_state = _ProfileState()
def start_trace(log_dir):
  """Starts a profiler trace.

  The trace will capture CPU, GPU, and/or TPU activity, including Python
  functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
  and save the results to ``log_dir``.

  The resulting trace can be viewed with TensorBoard. Note that TensorBoard
  doesn't need to be running when collecting the trace.

  Only one trace may be collected at a time. A RuntimeError will be raised if
  ``start_trace()`` is called while another trace is running.

  Args:
    log_dir: The directory to save the profiler trace to (usually the
      TensorBoard log directory).
  """
  with _profile_state.lock:
    if _profile_state.profile_session is not None:
      raise RuntimeError("Profile has already been started. "
                         "Only one profile may be run at a time.")
    _profile_state.profile_session = xla_client.profiler.ProfilerSession()
    _profile_state.log_dir = log_dir
def stop_trace():
  """Stops the currently-running profiler trace.

  The trace will be saved to the ``log_dir`` passed to the corresponding
  ``start_trace()`` call. Raises a RuntimeError if a trace hasn't been started.
  """
  with _profile_state.lock:
    if _profile_state.profile_session is None:
      raise RuntimeError("No profile started")
    _profile_state.profile_session.stop_and_export(_profile_state.log_dir)
    _profile_state.profile_session = None
    _profile_state.log_dir = None
@contextmanager
def trace(log_dir):
  """Context manager to take a profiler trace.

  The trace will capture CPU, GPU, and/or TPU activity, including Python
  functions and JAX on-device operations.

  The resulting trace can be viewed with TensorBoard. Note that TensorBoard
  doesn't need to be running when collecting the trace.

  Only one trace may be collected at a time. A RuntimeError will be raised if a
  trace is started while another trace is running.

  Args:
    log_dir: The directory to save the profiler trace to (usually the
      TensorBoard log directory).
  """
  start_trace(log_dir)
  try:
    yield
  finally:
    stop_trace()
class TraceAnnotation(xla_client.profiler.TraceMe):
"""Context manager that generates a trace event in the profiler.
The trace event spans the duration of the code enclosed by the context.
@ -42,16 +117,25 @@ class TraceContext(xla_client.profiler.TraceMe):
>>> import jax, jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
>>> with jax.profiler.TraceContext("acontext"):
>>> with jax.profiler.TraceAnnotation("my_label"):
... jnp.dot(x, x.T).block_until_ready()
This will cause an "acontext" event to show up on the trace timeline if the
event occurs while the process is being traced by TensorBoard.
This will cause a "my_label" event to show up on the trace timeline if the
event occurs while the process is being traced.
"""
pass
class StepTraceContext(TraceContext):
# TODO: remove this sometime after jax 0.1.11 is released
class TraceContext(TraceAnnotation):
def __init__(self, *args, **kwargs):
warnings.warn(
"TraceContext has been renamed to TraceAnnotation. This alias "
"will eventually be removed; please update your code.")
super().__init__(*args, **kwargs)
class StepTraceAnnotation(TraceAnnotation):
"""Context manager that generates a step trace event in the profiler.
The step trace event spans the duration of the code enclosed by the context.
@ -63,7 +147,7 @@ class StepTraceContext(TraceContext):
>>> import jax
>>>
>>> while global_step < NUM_STEPS:
... with jax.profiler.StepTraceContext("train", step_num=global_step):
... with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
... train_step()
... global_step += 1
@ -79,14 +163,23 @@ class StepTraceContext(TraceContext):
super().__init__(name, _r=1, **kwargs)
def trace_function(func: Callable, name: str = None, **kwargs):
# TODO: remove this sometime after jax 0.1.11 is released
class StepTraceContext(StepTraceAnnotation):
def __init__(self, *args, **kwargs):
warnings.warn(
"StepTraceContext has been renamed to StepTraceAnnotation. This alias "
"will eventually be removed; please update your code.")
super().__init__(*args, **kwargs)
def annotate_function(func: Callable, name: str = None, **kwargs):
"""Decorator that generates a trace event for the execution of a function.
For example:
>>> import jax, jax.numpy as jnp
>>>
>>> @jax.profiler.trace_function
>>> @jax.profiler.annotate_function
>>> def f(x):
... return jnp.dot(x, x.T).block_until_ready()
>>>
@ -111,12 +204,21 @@ def trace_function(func: Callable, name: str = None, **kwargs):
name = name or func.__name__
@wraps(func)
def wrapper(*args, **kwargs):
with TraceContext(name, **kwargs):
with TraceAnnotation(name, **kwargs):
return func(*args, **kwargs)
return wrapper
return wrapper
# TODO: remove this sometime after jax 0.1.11 is released
def trace_function(*args, **kwargs):
warnings.warn(
"trace_function has been renamed to annotate_function. This alias "
"will eventually be removed; please update your code.")
return annotate_function(*args, **kwargs)
def device_memory_profile(backend: Optional[str] = None) -> bytes:
"""Captures a JAX device memory profile as ``pprof``-format protocol buffer.


@ -14,10 +14,16 @@
# flake8: noqa: F401
from jax._src.profiler import (
StepTraceAnnotation,
StepTraceContext,
TraceAnnotation,
TraceContext,
device_memory_profile,
save_device_memory_profile,
start_server,
start_trace,
stop_trace,
trace,
annotate_function,
trace_function,
)


@ -16,6 +16,7 @@ from functools import partial
import glob
import os
import shutil
import tempfile
import threading
import unittest
from absl.testing import absltest
@ -56,23 +57,69 @@ class ProfilerTest(unittest.TestCase):
jax.profiler.start_server(port=port)
del port
def testTraceContext(self):
def testProgrammaticProfiling(self):
with tempfile.TemporaryDirectory() as tmpdir:
jax.profiler.start_trace(tmpdir)
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
jnp.ones(jax.local_device_count()))
jax.profiler.stop_trace()
proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
recursive=True)
self.assertEqual(len(proto_path), 1)
with open(proto_path[0], "rb") as f:
proto = f.read()
# Sanity check that serialized proto contains host and device traces
# without deserializing.
self.assertIn(b"/host:CPU", proto)
if jtu.device_under_test() == "tpu":
self.assertIn(b"/device:TPU", proto)
  def testProgrammaticProfilingErrors(self):
    with self.assertRaisesRegex(RuntimeError, "No profile started"):
      jax.profiler.stop_trace()

    with tempfile.TemporaryDirectory() as tmpdir:
      jax.profiler.start_trace(tmpdir)
      try:
        with self.assertRaisesRegex(RuntimeError,
                                    "Profile has already been started. Only one "
                                    "profile may be run at a time."):
          jax.profiler.start_trace(tmpdir)
      finally:
        # Stop the trace so it doesn't stay active for later tests.
        jax.profiler.stop_trace()
def testProgrammaticProfilingContextManager(self):
with tempfile.TemporaryDirectory() as tmpdir:
with jax.profiler.trace(tmpdir):
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
jnp.ones(jax.local_device_count()))
proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
recursive=True)
self.assertEqual(len(proto_path), 1)
with open(proto_path[0], "rb") as f:
proto = f.read()
# Sanity check that serialized proto contains host and device traces
# without deserializing.
self.assertIn(b"/host:CPU", proto)
if jtu.device_under_test() == "tpu":
self.assertIn(b"/device:TPU", proto)
def testTraceAnnotation(self):
x = 3
with jax.profiler.TraceContext("mycontext"):
with jax.profiler.TraceAnnotation("mycontext"):
x = x + 2
def testTraceFunction(self):
@jax.profiler.trace_function
@jax.profiler.annotate_function
def f(x):
return x + 2
self.assertEqual(f(7), 9)
@partial(jax.profiler.trace_function, name="aname")
@partial(jax.profiler.annotate_function, name="aname")
def g(x):
return x + 2
self.assertEqual(g(7), 9)
@partial(jax.profiler.trace_function, name="aname", akwarg="hello")
@partial(jax.profiler.annotate_function, name="aname", akwarg="hello")
def h(x):
return x + 2
self.assertEqual(h(7), 9)
@ -96,7 +143,7 @@ class ProfilerTest(unittest.TestCase):
worker_start.set()
x = jnp.ones((1000, 1000))
while True:
with jax.profiler.TraceContext("atracecontext"):
with jax.profiler.TraceAnnotation("atraceannotation"):
jnp.dot(x, x.T).block_until_ready()
if self.profile_done:
break