diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2dc46bf81..7c383aa73 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,9 +10,17 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 ## jax 0.2.12 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master).
-
+* New features
+  * New profiling APIs: {func}`jax.profiler.start_trace`,
+    {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
 * Breaking changes:
   * The minimum jaxlib version is now 0.1.64.
+  * Some profiler API names have been changed. Aliases are still in place, so
+    this should not break existing code, but the aliases will eventually be
+    removed, so please update your code.
+    * `TraceContext` --> {class}`~jax.profiler.TraceAnnotation`
+    * `StepTraceContext` --> {class}`~jax.profiler.StepTraceAnnotation`
+    * `trace_function` --> {func}`~jax.profiler.annotate_function`

 ## jaxlib 0.1.65 (unreleased)
diff --git a/docs/jax.profiler.rst b/docs/jax.profiler.rst
index cbd6584bf..7e99fc388 100644
--- a/docs/jax.profiler.rst
+++ b/docs/jax.profiler.rst
@@ -15,8 +15,12 @@ features.
   :toctree: _autosummary

     start_server
-    trace_function
-    TraceContext
+    start_trace
+    stop_trace
+    trace
+    annotate_function
+    TraceAnnotation
+    StepTraceAnnotation

 Device memory profiling
@@ -30,4 +34,3 @@ profiling features.

   device_memory_profile
   save_device_memory_profile
-
\ No newline at end of file
diff --git a/docs/profiling.md b/docs/profiling.md
index d7244bd91..1fe324b45 100644
--- a/docs/profiling.md
+++ b/docs/profiling.md
@@ -18,7 +18,69 @@ Install specific nightly versions of TensorBoard, TensorBoard profiler, TensorFl
pip install --upgrade tb-nightly==2.5.0a20201203 tbp-nightly==2.4.0a20201203 tf-nightly==2.5.0.dev20201203 tensorboard-plugin-wit==1.7.0
```
-### Usage
+### Programmatic capture
+
+You can instrument your code to capture a profiler trace via the
+{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace`
+functions. Call {func}`~jax.profiler.start_trace` with the directory to write
+trace files to. This should be the same `--logdir` directory used to start
+TensorBoard. Then, you can use TensorBoard to view the traces.
+
+For example, to take a profiler trace:
+
+```python
+import jax
+
+jax.profiler.start_trace("/tmp/tensorboard")
+
+# Run the operations to be profiled
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (5000, 5000))
+y = x @ x
+y.block_until_ready()
+
+jax.profiler.stop_trace()
+```
+
+Note the {func}`block_until_ready` call. We use this to make sure on-device
+execution is captured by the trace. See {ref}`async-dispatch` for details on
+why this is necessary.
+
+You can also use the {func}`jax.profiler.trace` context manager as an
+alternative to `start_trace` and `stop_trace`:
+
+```python
+import jax
+
+with jax.profiler.trace("/tmp/tensorboard"):
+  key = jax.random.PRNGKey(0)
+  x = jax.random.normal(key, (5000, 5000))
+  y = x @ x
+  y.block_until_ready()
+```
+
+To view the trace, first start TensorBoard if you haven't already:
+
+```shell
+$ tensorboard --logdir /tmp/tensorboard
+Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+TensorBoard 2.5.0a20210304 at http://localhost:6006/ (Press CTRL+C to quit)
+```
+
+You should be able to load TensorBoard at <http://localhost:6006/> in this
+example. You can specify a different port with the `--port` flag. See
+{ref}`remote_profiling` below if running JAX on a remote server.
+
+Then, either select "Profile" in the upper-right dropdown menu, or go directly
+to <http://localhost:6006/#profile>.
Available traces appear in the "Runs" +dropdown menu on the left. Select the run you're interested in, and then under +"Tools", select "trace_viewer". You should now see a timeline of the +execution. You can use the WASD keys to navigate the trace, and click or drag to +select events to see more details at the bottom. See [these TensorFlow +docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) +for more details on using the trace viewer. + +### Manual capture via TensorBoard The following are instructions for capturing a manually-triggered N-second trace from a running program. @@ -43,7 +105,7 @@ from a running program. This starts the profiler server that TensorBoard connects to. The profiler server must be running before you move on to the next step. It will remain - alive and listening until the object returned by `start_server()` is + alive and listening until the object returned by `start_server()` is destroyed. If you'd like to profile a snippet of a long-running program (e.g. a long @@ -76,10 +138,12 @@ from a running program. docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) for more details on using the trace viewer.
-1. By default, the events in the trace viewer are mostly low-level internal JAX
-   functions. You can add your own events and functions by using
-   {func}`jax.profiler.TraceContext` and {func}`jax.profiler.trace_function` in
-   your code and capturing a new profile.
+### Adding custom trace events
+
+By default, the events in the trace viewer are mostly low-level internal JAX
+functions. You can add your own events and functions by using
+{class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in
+your code.

 ### Troubleshooting
diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py
index b8a7fa265..192e4487a 100644
--- a/jax/_src/profiler.py
+++ b/jax/_src/profiler.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from contextlib import contextmanager
 from functools import wraps
+import threading
 from typing import Callable, Optional
+import warnings

 from jax.lib import xla_bridge
 from jax.lib import xla_client
@@ -24,8 +27,8 @@ def start_server(port: int):
   Using the "TensorFlow profiler" feature in `TensorBoard
   <https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can
-  connect to the profiler server and sample execution traces that show CPU and
-  GPU device activity.
+  connect to the profiler server and sample execution traces that show CPU,
+  GPU, and/or TPU device activity.

   Returns a profiler server object. The server remains alive and listening
   until the server object is destroyed.
@@ -33,8 +36,80 @@
   return xla_client.profiler.start_server(port)


-class TraceContext(xla_client.profiler.TraceMe):
-  """Context manager generates a trace event in the profiler.
+class _ProfileState(object):
+  def __init__(self):
+    self.profile_session = None
+    self.log_dir = None
+    self.lock = threading.Lock()
+
+_profile_state = _ProfileState()
+
+
+def start_trace(log_dir):
+  """Starts a profiler trace.
+
+  The trace will capture CPU, GPU, and/or TPU activity, including Python
+  functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
+  and save the results to ``log_dir``.
+
+  The resulting trace can be viewed with TensorBoard. Note that TensorBoard
+  doesn't need to be running when collecting the trace.
+
+  Only one trace may be collected at a time. A RuntimeError will be raised if
+  ``start_trace()`` is called while another trace is running.
+
+  Args:
+    log_dir: The directory to save the profiler trace to (usually the
+      TensorBoard log directory).
+  """
+  with _profile_state.lock:
+    if _profile_state.profile_session is not None:
+      raise RuntimeError("Profile has already been started. "
+                         "Only one profile may be run at a time.")
+    _profile_state.profile_session = xla_client.profiler.ProfilerSession()
+    _profile_state.log_dir = log_dir
+
+
+def stop_trace():
+  """Stops the currently-running profiler trace.
+
+  The trace will be saved to the ``log_dir`` passed to the corresponding
+  ``start_trace()`` call. Raises a RuntimeError if a trace hasn't been started.
+  """
+  with _profile_state.lock:
+    if _profile_state.profile_session is None:
+      raise RuntimeError("No profile started")
+    _profile_state.profile_session.stop_and_export(_profile_state.log_dir)
+    _profile_state.profile_session = None
+    _profile_state.log_dir = None
+
+
+@contextmanager
+def trace(log_dir):
+  """Context manager to take a profiler trace.
+
+  The trace will capture CPU, GPU, and/or TPU activity, including Python
+  functions and JAX on-device operations.
+
+  The resulting trace can be viewed with TensorBoard.
+  Note that TensorBoard doesn't need to be running when collecting the trace.
+
+  Only one trace may be collected at a time. A RuntimeError will be raised if a
+  trace is started while another trace is running.
+
+  Args:
+    log_dir: The directory to save the profiler trace to (usually the
+      TensorBoard log directory).
+  """
+  start_trace(log_dir)
+  try:
+    yield
+  finally:
+    stop_trace()
+
+
+class TraceAnnotation(xla_client.profiler.TraceMe):
+  """Context manager that generates a trace event in the profiler.

   The trace event spans the duration of the code enclosed by the context.
@@ -42,16 +117,25 @@ class TraceContext(xla_client.profiler.TraceMe):
   >>> import jax, jax.numpy as jnp
   >>> x = jnp.ones((1000, 1000))
-  >>> with jax.profiler.TraceContext("acontext"):
+  >>> with jax.profiler.TraceAnnotation("my_label"):
   ...   jnp.dot(x, x.T).block_until_ready()

-  This will cause an "acontext" event to show up on the trace timeline if the
-  event occurs while the process is being traced by TensorBoard.
+  This will cause a "my_label" event to show up on the trace timeline if the
+  event occurs while the process is being traced.
   """
   pass


-class StepTraceContext(TraceContext):
+# TODO: remove this sometime after jax 0.2.12 is released
+class TraceContext(TraceAnnotation):
+  def __init__(self, *args, **kwargs):
+    warnings.warn(
+        "TraceContext has been renamed to TraceAnnotation. This alias "
+        "will eventually be removed; please update your code.")
+    super().__init__(*args, **kwargs)
+
+
+class StepTraceAnnotation(TraceAnnotation):
   """Context manager that generates a step trace event in the profiler.

   The step trace event spans the duration of the code enclosed by the context.
@@ -63,7 +147,7 @@ class StepTraceContext(TraceContext):
   >>> import jax
   >>>
   >>> while global_step < NUM_STEPS:
-  ...   with jax.profiler.StepTraceContext("train", step_num=global_step):
+  ...   with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
   ...     train_step()
   ...     global_step += 1
@@ -79,14 +163,23 @@ class StepTraceContext(TraceContext):
     super().__init__(name, _r=1, **kwargs)


-def trace_function(func: Callable, name: str = None, **kwargs):
+# TODO: remove this sometime after jax 0.2.12 is released
+class StepTraceContext(StepTraceAnnotation):
+  def __init__(self, *args, **kwargs):
+    warnings.warn(
+        "StepTraceContext has been renamed to StepTraceAnnotation. This alias "
+        "will eventually be removed; please update your code.")
+    super().__init__(*args, **kwargs)
+
+
+def annotate_function(func: Callable, name: str = None, **kwargs):
   """Decorator that generates a trace event for the execution of a function.

   For example:

   >>> import jax, jax.numpy as jnp
   >>>
-  >>> @jax.profiler.trace_function
+  >>> @jax.profiler.annotate_function
   >>> def f(x):
   ...   return jnp.dot(x, x.T).block_until_ready()
   >>>
@@ -111,12 +204,21 @@ def trace_function(func: Callable, name: str = None, **kwargs):
   name = name or func.__name__
   @wraps(func)
   def wrapper(*args, **kwargs):
-    with TraceContext(name, **kwargs):
+    with TraceAnnotation(name, **kwargs):
       return func(*args, **kwargs)
     return wrapper
   return wrapper


+# TODO: remove this sometime after jax 0.2.12 is released
+def trace_function(*args, **kwargs):
+  warnings.warn(
+      "trace_function has been renamed to annotate_function. This alias "
+      "will eventually be removed; please update your code.")
+  return annotate_function(*args, **kwargs)
+
+
+
 def device_memory_profile(backend: Optional[str] = None) -> bytes:
   """Captures a JAX device memory profile as ``pprof``-format protocol buffer.
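For reference, a minimal sketch (not part of the diff) of how the deprecation aliases above behave at call sites: the old names keep working but emit a warning that names the replacement. The "legacy_label" label and the `warnings.catch_warnings` usage are illustrative.

```python
import warnings

import jax

# The old name still works via the TraceContext alias defined above...
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  with jax.profiler.TraceContext("legacy_label"):
    pass

# ...but constructing it now emits a warning that points at TraceAnnotation.
assert any("TraceAnnotation" in str(w.message) for w in caught)

# The new name is a drop-in replacement and emits no warning.
with jax.profiler.TraceAnnotation("legacy_label"):
  pass
```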
diff --git a/jax/profiler.py b/jax/profiler.py index a66eb9f8e..e8e5db9e6 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -14,10 +14,16 @@ # flake8: noqa: F401 from jax._src.profiler import ( + StepTraceAnnotation, StepTraceContext, + TraceAnnotation, TraceContext, device_memory_profile, save_device_memory_profile, start_server, + start_trace, + stop_trace, + trace, + annotate_function, trace_function, ) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index c7aa297b1..21ceb632c 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -16,6 +16,7 @@ from functools import partial import glob import os import shutil +import tempfile import threading import unittest from absl.testing import absltest @@ -56,23 +57,69 @@ class ProfilerTest(unittest.TestCase): jax.profiler.start_server(port=port) del port - def testTraceContext(self): + def testProgrammaticProfiling(self): + with tempfile.TemporaryDirectory() as tmpdir: + jax.profiler.start_trace(tmpdir) + jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( + jnp.ones(jax.local_device_count())) + jax.profiler.stop_trace() + + proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"), + recursive=True) + self.assertEqual(len(proto_path), 1) + with open(proto_path[0], "rb") as f: + proto = f.read() + # Sanity check that serialized proto contains host and device traces + # without deserializing. + self.assertIn(b"/host:CPU", proto) + if jtu.device_under_test() == "tpu": + self.assertIn(b"/device:TPU", proto) + + def testProgrammaticProfilingErrors(self): + with self.assertRaisesRegex(RuntimeError, "No profile started"): + jax.profiler.stop_trace() + + with tempfile.TemporaryDirectory() as tmpdir: + jax.profiler.start_trace(tmpdir) + with self.assertRaisesRegex(RuntimeError, + "Profile has already been started. Only one " + "profile may be run at a time."): + jax.profiler.start_trace(tmpdir) + + def testProgrammaticProfilingContextManager(self): + with tempfile.TemporaryDirectory() as tmpdir: + with jax.profiler.trace(tmpdir): + jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( + jnp.ones(jax.local_device_count())) + + proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"), + recursive=True) + self.assertEqual(len(proto_path), 1) + with open(proto_path[0], "rb") as f: + proto = f.read() + # Sanity check that serialized proto contains host and device traces + # without deserializing. + self.assertIn(b"/host:CPU", proto) + if jtu.device_under_test() == "tpu": + self.assertIn(b"/device:TPU", proto) + + def testTraceAnnotation(self): x = 3 - with jax.profiler.TraceContext("mycontext"): + with jax.profiler.TraceAnnotation("mycontext"): x = x + 2 def testTraceFunction(self): - @jax.profiler.trace_function + @jax.profiler.annotate_function def f(x): return x + 2 self.assertEqual(f(7), 9) - @partial(jax.profiler.trace_function, name="aname") + @partial(jax.profiler.annotate_function, name="aname") def g(x): return x + 2 self.assertEqual(g(7), 9) - @partial(jax.profiler.trace_function, name="aname", akwarg="hello") + @partial(jax.profiler.annotate_function, name="aname", akwarg="hello") def h(x): return x + 2 self.assertEqual(h(7), 9) @@ -96,7 +143,7 @@ class ProfilerTest(unittest.TestCase): worker_start.set() x = jnp.ones((1000, 1000)) while True: - with jax.profiler.TraceContext("atracecontext"): + with jax.profiler.TraceAnnotation("atraceannotation"): jnp.dot(x, x.T).block_until_ready() if self.profile_done: break
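Taken together, the new programmatic capture functions and the renamed annotation APIs compose roughly as in the sketch below. The log directory, the "train" label, and the `train_step` body are placeholders rather than part of this change; the resulting trace can then be opened in TensorBoard as described in `docs/profiling.md`.

```python
import jax
import jax.numpy as jnp


def train_step(x):
  # Placeholder for a real training step.
  return jnp.dot(x, x.T).block_until_ready()


x = jnp.ones((1000, 1000))

# Write the trace to the same directory passed to `tensorboard --logdir`.
jax.profiler.start_trace("/tmp/tensorboard")
for step in range(3):
  # Each step shows up as a named "train" step event in the trace viewer.
  with jax.profiler.StepTraceAnnotation("train", step_num=step):
    train_step(x)
jax.profiler.stop_trace()

# Equivalently, the context manager form handles start/stop automatically.
with jax.profiler.trace("/tmp/tensorboard"):
  train_step(x)
```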