Add an option to create a perfetto link in the JAX profiler

This commit is contained in:
Sharad Vikram 2022-05-25 16:01:16 -07:00
parent b80d7195f6
commit 76669835ba
4 changed files with 128 additions and 4 deletions

View File

@ -32,6 +32,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.numpy.ldexp` no longer silently promotes all inputs to float64,
instead it promotes to float32 for integer inputs of size int32 or smaller
({jax-issue}`#10921`).
* Add a `create_perfetto_link` option to {func}`jax.profiler.start_trace` and
{func}`jax.profiler.trace`. When used, the profiler will generate a
link to the Perfetto UI to view the trace.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

BIN
docs/_static/perfetto.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

View File

@ -1,5 +1,46 @@
# Profiling JAX programs
## Viewing program traces with Perfetto
We can use the JAX profiler to generate traces of a JAX program that can be
visualized using the [Perfetto visualizer](https://ui.perfetto.dev). Currently,
this method blocks the program until a link is clicked and the Perfetto UI loads
the trace. If you wish to get profiling information without any interaction,
check out the TensorBoard profiler below.
```python
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()
```
After this computation is done, the program will prompt you to open a link to
`ui.perfetto.dev`. When you open the link, the Perfetto UI will load the trace
file and open a visualizer.
![Perfetto trace viewer](_static/perfetto.png)
Program execution will continue after loading the link. The link is no longer
valid after opening once, but it will redirect to a new URL that remains valid.
You can then click the "Share" button in the Perfetto UI to create a permalink
to the trace that can be shared with others.
### Remote profiling
When profiling code that is running remotely (for example on a hosted VM),
you need to establish an SSH tunnel on port 9001 for the link to work. You can
do that with this command:
```bash
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
```
or if you're using Google Cloud:
```bash
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
```
## TensorBoard profiling
[TensorBoard's

View File

@ -14,10 +14,18 @@
from contextlib import contextmanager
from functools import wraps
import glob
import gzip
import http.server
import json
import os
import socketserver
import threading
from typing import Callable, Optional
import warnings
from typing import Callable, Optional
from absl import logging
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
@ -43,12 +51,13 @@ class _ProfileState:
def __init__(self):
  """Initializes the state to "no trace running, nothing configured"."""
  # Lock that `start_trace` holds while it reads and updates the fields below.
  self.lock = threading.Lock()
  # The active profiler session; `None` means no profile has been started.
  self.profile_session = None
  # Directory the active trace will be exported to.
  self.log_dir = None
  # Whether stopping the trace should also serve a Perfetto link.
  self.create_perfetto_link = False
# Module-level profiler state shared by `start_trace` and `stop_trace`.
_profile_state = _ProfileState()
def start_trace(log_dir):
def start_trace(log_dir, create_perfetto_link: bool = False):
"""Starts a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python
@ -64,14 +73,79 @@ def start_trace(log_dir):
Args:
log_dir: The directory to save the profiler trace to (usually the
TensorBoard log directory).
create_perfetto_link: A boolean which, if true, creates and prints a link to
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
block until the link is opened and Perfetto loads the trace.
"""
with _profile_state.lock:
if _profile_state.profile_session is not None:
raise RuntimeError("Profile has already been started. "
"Only one profile may be run at a time.")
_profile_state.profile_session = xla_client.profiler.ProfilerSession()
_profile_state.create_perfetto_link = create_perfetto_link
_profile_state.log_dir = log_dir
def _write_perfetto_trace_file(log_dir):
  """Writes a Perfetto-loadable copy of the latest trace under `log_dir`.

  Picks the most recently modified entry of `<log_dir>/plugins/profile`, loads
  its single `*.trace.json.gz` file, drops the top-level `metadata` field (if
  present) and writes the result next to it as `perfetto_trace.json.gz`.

  Args:
    log_dir: The directory the profiler trace was exported to.

  Returns:
    The absolute path of the written `perfetto_trace.json.gz` file.

  Raises:
    ValueError: If no trace folder exists, or if the latest trace folder does
      not contain exactly one `*.trace.json.gz` file.
  """
  # Navigate to folder with the latest trace dump to find `trace.json.gz`
  curr_path = os.path.abspath(log_dir)
  root_trace_folder = os.path.join(curr_path, "plugins", "profile")
  trace_folders = [os.path.join(root_trace_folder, trace_folder) for
                   trace_folder in os.listdir(root_trace_folder)]
  if not trace_folders:
    # Without this check `max` fails with an unhelpful "empty sequence" error.
    raise ValueError(f"No trace folders found in: {root_trace_folder}")
  latest_folder = max(trace_folders, key=os.path.getmtime)
  trace_jsons = glob.glob(os.path.join(latest_folder, "*.trace.json.gz"))
  if len(trace_jsons) != 1:
    raise ValueError(f"Invalid trace folder: {latest_folder}")
  trace_json, = trace_jsons

  logging.info("Loading trace.json.gz and removing its metadata...")
  # Perfetto doesn't like the `metadata` field in `trace.json` so we remove
  # it.
  # TODO(sharadmv): speed this up by updating the generated `trace.json`
  # to not include metadata if possible.
  with gzip.open(trace_json, "rb") as fp:
    trace = json.load(fp)
  # `pop` rather than `del`: a trace without metadata is already acceptable.
  trace.pop("metadata", None)

  filename = "perfetto_trace.json.gz"
  perfetto_trace = os.path.join(latest_folder, filename)
  logging.info("Writing perfetto_trace.json.gz...")
  with gzip.open(perfetto_trace, "w") as fp:
    fp.write(json.dumps(trace).encode("utf-8"))
  return perfetto_trace
class _PerfettoServer(http.server.SimpleHTTPRequestHandler):
  """Request handler that serves the trace file to `ui.perfetto.dev`.

  Behaves like `SimpleHTTPRequestHandler`, except that every response carries
  a permissive CORS header, the path of each GET is recorded on the server
  object, and POST requests are always rejected.
  """

  def end_headers(self):
    """Appends the CORS header, then finishes the headers as usual."""
    self.send_header("Access-Control-Allow-Origin", "*")
    return super().end_headers()

  def do_GET(self):
    """Records the requested path on the server, then serves the file."""
    # Stored on the server (not the handler) so the code driving the server
    # can inspect which path was requested last.
    self.server.last_request = self.path
    return super().do_GET()

  def do_POST(self):
    """Rejects every POST request with a 404."""
    self.send_error(404, "File not found")
def _host_perfetto_trace_file(log_dir):
  """Serves the latest trace over HTTP and prints a Perfetto link to it.

  Writes the Perfetto trace file for `log_dir`, serves its folder on
  `127.0.0.1:9001`, prints a `ui.perfetto.dev` URL pointing at it, and blocks
  until that file has been requested. The working directory is changed to the
  trace folder while serving and restored afterwards.

  Args:
    log_dir: The directory the profiler trace was exported to.
  """
  # ui.perfetto.dev looks for files hosted on `127.0.0.1:9001`. We set up a
  # TCP server that is hosting the `perfetto_trace.json.gz` file.
  port = 9001
  abs_filename = _write_perfetto_trace_file(log_dir)
  orig_directory = os.path.abspath(os.getcwd())
  directory, filename = os.path.split(abs_filename)

  # Enable address reuse on a private subclass so the setting does not leak
  # onto `socketserver.TCPServer` for the rest of the process.
  class _ReusableTCPServer(socketserver.TCPServer):
    allow_reuse_address = True

  try:
    # The request handler serves files relative to the working directory.
    os.chdir(directory)
    with _ReusableTCPServer(('127.0.0.1', port), _PerfettoServer) as httpd:
      url = f"https://ui.perfetto.dev/#!/?url=http://127.0.0.1:{port}/{filename}"
      print(f"Open URL in browser: {url}")

      # Once ui.perfetto.dev acquires trace.json from this server we can close
      # it down.
      while getattr(httpd, "last_request", None) != '/' + filename:
        httpd.handle_request()
  finally:
    os.chdir(orig_directory)
def stop_trace():
"""Stops the currently-running profiler trace.
@ -83,12 +157,15 @@ def stop_trace():
if _profile_state.profile_session is None:
raise RuntimeError("No profile started")
_profile_state.profile_session.stop_and_export(_profile_state.log_dir)
if _profile_state.create_perfetto_link:
_host_perfetto_trace_file(_profile_state.log_dir)
_profile_state.profile_session = None
_profile_state.create_perfetto_link = False
_profile_state.log_dir = None
@contextmanager
def trace(log_dir):
def trace(log_dir, create_perfetto_link=False):
"""Context manager to take a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python
@ -103,8 +180,11 @@ def trace(log_dir):
Args:
log_dir: The directory to save the profiler trace to (usually the
TensorBoard log directory).
create_perfetto_link: A boolean which, if true, creates and prints a link to
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
block until the link is opened and Perfetto loads the trace.
"""
start_trace(log_dir)
start_trace(log_dir, create_perfetto_link)
try:
yield
finally: