mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add an option to create a perfetto link in the JAX profiler
This commit is contained in:
parent
b80d7195f6
commit
76669835ba
@ -32,6 +32,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* {func}`jax.numpy.ldexp` no longer silently promotes all inputs to float64,
|
||||
instead it promotes to float32 for integer inputs of size int32 or smaller
|
||||
({jax-issue}`#10921`).
|
||||
* Add a `create_perfetto_link` option to {func}`jax.profiler.start_trace` and
|
||||
{func}`jax.profiler.trace`. When used, the profiler will generate a
|
||||
link to the Perfetto UI to view the trace.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
BIN
docs/_static/perfetto.png
vendored
Normal file
BIN
docs/_static/perfetto.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 96 KiB |
@ -1,5 +1,46 @@
|
||||
# Profiling JAX programs
|
||||
|
||||
## Viewing program traces with Perfetto
|
||||
|
||||
We can use the JAX profiler to generate traces of a JAX program that can be
|
||||
visualized using the [Perfetto visualizer](https://ui.perfetto.dev). Currently,
|
||||
this method blocks the program until a link is clicked and the Perfetto UI loads
|
||||
the trace. If you wish to get profiling information without any interaction,
|
||||
check out the TensorBoard profiler below.
|
||||
|
||||
```python
|
||||
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
|
||||
# Run the operations to be profiled
|
||||
key = jax.random.PRNGKey(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
```
|
||||
|
||||
After this computation is done, the program will prompt you to open a link to
|
||||
`ui.perfetto.dev`. When you open the link, the Perfetto UI will load the trace
|
||||
file and open a visualizer.
|
||||
|
||||

|
||||
|
||||
Program execution will continue after loading the link. The link is no longer
|
||||
valid after opening once, but it will redirect to a new URL that remains valid.
|
||||
You can then click the "Share" button in the Perfetto UI to create a permalink
|
||||
to the trace that can be shared with others.
|
||||
|
||||
### Remote profiling
|
||||
|
||||
When profiling code that is running remotely (for example on a hosted VM),
|
||||
you need to establish an SSH tunnel on port 9001 for the link to work. You can
|
||||
do that with this command:
|
||||
```bash
|
||||
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
|
||||
```
|
||||
or if you're using Google Cloud:
|
||||
```bash
|
||||
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
|
||||
```
|
||||
|
||||
## TensorBoard profiling
|
||||
|
||||
[TensorBoard's
|
||||
|
@ -14,10 +14,18 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
import glob
|
||||
import gzip
|
||||
import http.server
|
||||
import json
|
||||
import os
|
||||
import socketserver
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
import warnings
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from absl import logging
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
@ -43,12 +51,13 @@ class _ProfileState:
|
||||
def __init__(self):
  """Sets up the profiler state with no trace in progress."""
  # Held by callers while the fields below are read or modified.
  self.lock = threading.Lock()
  # Whether stopping the trace should also produce a Perfetto UI link.
  self.create_perfetto_link = False
  # Directory the current trace will be exported to; None when idle.
  self.log_dir = None
  # The active profiler session; None when no trace is running.
  self.profile_session = None
|
||||
|
||||
_profile_state = _ProfileState()
|
||||
|
||||
|
||||
def start_trace(log_dir):
|
||||
def start_trace(log_dir, create_perfetto_link: bool = False):
|
||||
"""Starts a profiler trace.
|
||||
|
||||
The trace will capture CPU, GPU, and/or TPU activity, including Python
|
||||
@ -64,14 +73,79 @@ def start_trace(log_dir):
|
||||
Args:
|
||||
log_dir: The directory to save the profiler trace to (usually the
|
||||
TensorBoard log directory).
|
||||
create_perfetto_link: A boolean which, if true, creates and prints link to
|
||||
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
|
||||
block until the link is opened and Perfetto loads the trace.
|
||||
"""
|
||||
with _profile_state.lock:
|
||||
if _profile_state.profile_session is not None:
|
||||
raise RuntimeError("Profile has already been started. "
|
||||
"Only one profile may be run at a time.")
|
||||
_profile_state.profile_session = xla_client.profiler.ProfilerSession()
|
||||
_profile_state.create_perfetto_link = create_perfetto_link
|
||||
_profile_state.log_dir = log_dir
|
||||
|
||||
def _write_perfetto_trace_file(log_dir):
  """Writes a Perfetto-compatible copy of the latest trace under `log_dir`.

  Looks in `<log_dir>/plugins/profile` for the most recently modified trace
  folder, loads its single `*.trace.json.gz` file, drops the top-level
  `metadata` field (if present) and writes the result next to the original as
  `perfetto_trace.json.gz`.

  Args:
    log_dir: The directory the profiler trace was exported to.

  Returns:
    The absolute path of the written `perfetto_trace.json.gz` file.

  Raises:
    ValueError: If there is no trace folder under `log_dir`, or if the latest
      trace folder does not contain exactly one `*.trace.json.gz` file.
  """
  # Navigate to folder with the latest trace dump to find `trace.json.gz`
  curr_path = os.path.abspath(log_dir)
  root_trace_folder = os.path.join(curr_path, "plugins", "profile")
  trace_folders = [os.path.join(root_trace_folder, trace_folder) for
                   trace_folder in os.listdir(root_trace_folder)]
  if not trace_folders:
    # `max` on an empty list raises an unhelpful "empty sequence" error, so
    # report the folder that was expected to hold the trace instead.
    raise ValueError(f"No trace folders found in: {root_trace_folder}")
  latest_folder = max(trace_folders, key=os.path.getmtime)
  trace_jsons = glob.glob(os.path.join(latest_folder, "*.trace.json.gz"))
  if len(trace_jsons) != 1:
    raise ValueError(f"Invalid trace folder: {latest_folder}")
  trace_json, = trace_jsons

  logging.info("Loading trace.json.gz and removing its metadata...")
  # Perfetto doesn't like the `metadata` field in `trace.json` so we remove
  # it.
  # TODO(sharadmv): speed this up by updating the generated `trace.json`
  # to not include metadata if possible.
  with gzip.open(trace_json, "rb") as fp:
    trace = json.load(fp)
  # Use `pop` with a default so that a trace that already lacks the field is
  # accepted rather than failing with a KeyError.
  trace.pop("metadata", None)
  filename = "perfetto_trace.json.gz"
  perfetto_trace = os.path.join(latest_folder, filename)
  logging.info("Writing perfetto_trace.json.gz...")
  with gzip.open(perfetto_trace, "w") as fp:
    fp.write(json.dumps(trace).encode("utf-8"))
  return perfetto_trace
|
||||
|
||||
class _PerfettoServer(http.server.SimpleHTTPRequestHandler):
  """Request handler that serves the trace file to `ui.perfetto.dev`.

  GET requests are served as static files by the base class, and the path of
  the most recent GET is recorded on the owning server object. POST requests
  are always rejected.
  """

  def do_POST(self):
    # Nothing can be uploaded to this server; answer every POST with a 404.
    self.send_error(404, "File not found")

  def do_GET(self):
    # Remember which path was asked for so the code driving the server can
    # tell when the trace file has been requested.
    requested_path = self.path
    self.server.last_request = requested_path
    return super().do_GET()

  def end_headers(self):
    # Add a permissive CORS header to every response before the header block
    # is finished.
    self.send_header("Access-Control-Allow-Origin", "*")
    return super().end_headers()
|
||||
|
||||
def _host_perfetto_trace_file(log_dir):
  """Serves the Perfetto trace for `log_dir` locally and prints a link to it.

  Writes the Perfetto-compatible trace file, starts a local HTTP server on
  `127.0.0.1:9001` that serves it, prints a `ui.perfetto.dev` URL pointing at
  the file, and blocks handling requests until the trace file itself has been
  requested. The process working directory is changed while the server runs
  and restored afterwards.

  Args:
    log_dir: The directory the profiler trace was exported to.
  """
  # ui.perfetto.dev looks for files hosted on `127.0.0.1:9001`. We set up a
  # TCP server that is hosting the `perfetto_trace.json.gz` file.
  port = 9001
  abs_filename = _write_perfetto_trace_file(log_dir)
  orig_directory = os.path.abspath(os.getcwd())
  directory, filename = os.path.split(abs_filename)

  # Enable address reuse on a local subclass only, so that the setting does
  # not leak into every other `socketserver.TCPServer` in the process.
  class _ReusableTCPServer(socketserver.TCPServer):
    allow_reuse_address = True

  try:
    # The request handler serves files relative to the current working
    # directory, so move into the folder that holds the trace file.
    os.chdir(directory)
    with _ReusableTCPServer(('127.0.0.1', port), _PerfettoServer) as httpd:
      url = f"https://ui.perfetto.dev/#!/?url=http://127.0.0.1:{port}/{filename}"
      print(f"Open URL in browser: {url}")

      # Once ui.perfetto.dev acquires trace.json from this server we can close
      # it down. The handler records the path of each GET on the server
      # object; start from a known value so the first comparison is defined.
      httpd.last_request = None
      while httpd.last_request != '/' + filename:
        httpd.handle_request()
  finally:
    os.chdir(orig_directory)
|
||||
|
||||
def stop_trace():
|
||||
"""Stops the currently-running profiler trace.
|
||||
@ -83,12 +157,15 @@ def stop_trace():
|
||||
if _profile_state.profile_session is None:
|
||||
raise RuntimeError("No profile started")
|
||||
_profile_state.profile_session.stop_and_export(_profile_state.log_dir)
|
||||
if _profile_state.create_perfetto_link:
|
||||
_host_perfetto_trace_file(_profile_state.log_dir)
|
||||
_profile_state.profile_session = None
|
||||
_profile_state.create_perfetto_link = False
|
||||
_profile_state.log_dir = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace(log_dir):
|
||||
def trace(log_dir, create_perfetto_link=False):
|
||||
"""Context manager to take a profiler trace.
|
||||
|
||||
The trace will capture CPU, GPU, and/or TPU activity, including Python
|
||||
@ -103,8 +180,11 @@ def trace(log_dir):
|
||||
Args:
|
||||
log_dir: The directory to save the profiler trace to (usually the
|
||||
TensorBoard log directory).
|
||||
create_perfetto_link: A boolean which, if true, creates and prints link to
|
||||
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
|
||||
block until the link is opened and Perfetto loads the trace.
|
||||
"""
|
||||
start_trace(log_dir)
|
||||
start_trace(log_dir, create_perfetto_link)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
Loading…
x
Reference in New Issue
Block a user