diff --git a/CHANGELOG.md b/CHANGELOG.md index fc137bde6..f097bff61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * {func}`jax.numpy.ldexp` no longer silently promotes all inputs to float64, instead it promotes to float32 for integer inputs of size int32 or smaller ({jax-issue}`#10921`). + * Add a `create_perfetto_link` option to {func}`jax.profiler.start_trace` and + {func}`jax.profiler.start_trace`. When used, the profiler will generate a + link to the Perfetto UI to view the trace. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/docs/_static/perfetto.png b/docs/_static/perfetto.png new file mode 100644 index 000000000..0df1b6ebe Binary files /dev/null and b/docs/_static/perfetto.png differ diff --git a/docs/profiling.md b/docs/profiling.md index a8c2ea0f6..f04156cfe 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,5 +1,46 @@ # Profiling JAX programs +## Viewing program traces with Perfetto + +We can use the JAX profiler to generate traces of a JAX program that can be +visualized using the [Perfetto visualizer](https://ui.perfetto.dev). Currently, +this method blocks the program until a link is clicked and the Perfetto UI loads +the trace. If you wish to get profiling information without any interaction, +check out the the Tensorboard profiler below. + +```python +with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): + # Run the operations to be profiled + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (5000, 5000)) + y = x @ x + y.block_until_ready() +``` + +After this computation is done, the program will prompt you to open a link to +`ui.perfetto.dev`. When you open the link, the Perfetto UI will load the trace +file and open a visualizer. + +![Perfetto trace viewer](_static/perfetto.png) + +Program execution will continue after loading the link. The link is no longer +valid after opening once, but it will redirect to a new URL that remains valid. +You can then click the "Share" button in the Perfetto UI to create a permalink +to the trace that can be shared with others. + +### Remote profiling + +When profiling code that is running remotely (for example on a hosted VM), +you need to establish an SSH tunnel on port 9001 for the link to work. You can +do that with this command: +```bash +$ ssh -L 9001:127.0.0.1:9001 @ +``` +or if you're using Google Cloud: +```bash +$ gcloud compute ssh -- -L 9001:127.0.0.1:9001 +``` + ## TensorBoard profiling [TensorBoard's diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 5ca290b27..225f52ad6 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -14,10 +14,18 @@ from contextlib import contextmanager from functools import wraps +import glob +import gzip +import http.server +import json +import os +import socketserver import threading -from typing import Callable, Optional import warnings +from typing import Callable, Optional + +from absl import logging from jax._src import traceback_util traceback_util.register_exclusion(__file__) @@ -43,12 +51,13 @@ class _ProfileState: def __init__(self): self.profile_session = None self.log_dir = None + self.create_perfetto_link = False self.lock = threading.Lock() _profile_state = _ProfileState() -def start_trace(log_dir): +def start_trace(log_dir, create_perfetto_link: bool = False): """Starts a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python @@ -64,14 +73,79 @@ def start_trace(log_dir): Args: log_dir: The directory to save the profiler trace to (usually the TensorBoard log directory). + create_perfetto_link: A boolean which, if true, creates and prints link to + the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will + block until the link is opened and Perfetto loads the trace. """ with _profile_state.lock: if _profile_state.profile_session is not None: raise RuntimeError("Profile has already been started. " "Only one profile may be run at a time.") _profile_state.profile_session = xla_client.profiler.ProfilerSession() + _profile_state.create_perfetto_link = create_perfetto_link _profile_state.log_dir = log_dir +def _write_perfetto_trace_file(log_dir): + # Navigate to folder with the latest trace dump to find `trace.json.jz` + curr_path = os.path.abspath(log_dir) + root_trace_folder = os.path.join(curr_path, "plugins", "profile") + trace_folders = [os.path.join(root_trace_folder, trace_folder) for + trace_folder in os.listdir(root_trace_folder)] + latest_folder = max(trace_folders, key=os.path.getmtime) + trace_jsons = glob.glob(os.path.join(latest_folder, "*.trace.json.gz")) + if len(trace_jsons) != 1: + raise ValueError(f"Invalid trace folder: {latest_folder}") + trace_json, = trace_jsons + + logging.info("Loading trace.json.gz and removing its metadata...") + # Perfetto doesn't like the `metadata` field in `trace.json` so we remove + # it. + # TODO(sharadmv): speed this up by updating the generated `trace.json` + # to not include metadata if possible. + with gzip.open(trace_json, "rb") as fp: + trace = json.load(fp) + del trace["metadata"] + filename = "perfetto_trace.json.gz" + perfetto_trace = os.path.join(latest_folder, filename) + logging.info("Writing perfetto_trace.json.gz...") + with gzip.open(perfetto_trace, "w") as fp: + fp.write(json.dumps(trace).encode("utf-8")) + return perfetto_trace + +class _PerfettoServer(http.server.SimpleHTTPRequestHandler): + """Handles requests from `ui.perfetto.dev` for the `trace.json`""" + + def end_headers(self): + self.send_header('Access-Control-Allow-Origin', '*') + return super().end_headers() + + def do_GET(self): + self.server.last_request = self.path + return super().do_GET() + + def do_POST(self): + self.send_error(404, "File not found") + +def _host_perfetto_trace_file(log_dir): + # ui.perfetto.dev looks for files hosted on `127.0.0.1:9001`. We set up a + # TCP server that is hosting the `perfetto_trace.json.gz` file. + port = 9001 + abs_filename = _write_perfetto_trace_file(log_dir) + orig_directory = os.path.abspath(os.getcwd()) + directory, filename = os.path.split(abs_filename) + try: + os.chdir(directory) + socketserver.TCPServer.allow_reuse_address = True + with socketserver.TCPServer(('127.0.0.1', port), _PerfettoServer) as httpd: + url = f"https://ui.perfetto.dev/#!/?url=http://127.0.0.1:{port}/{filename}'" + print(f"Open URL in browser: {url}") + + # Once ui.perfetto.dev acquires trace.json from this server we can close + # it down. + while httpd.__dict__.get('last_request') != '/' + filename: + httpd.handle_request() + finally: + os.chdir(orig_directory) def stop_trace(): """Stops the currently-running profiler trace. @@ -83,12 +157,15 @@ def stop_trace(): if _profile_state.profile_session is None: raise RuntimeError("No profile started") _profile_state.profile_session.stop_and_export(_profile_state.log_dir) + if _profile_state.create_perfetto_link: + _host_perfetto_trace_file(_profile_state.log_dir) _profile_state.profile_session = None + _profile_state.create_perfetto_link = False _profile_state.log_dir = None @contextmanager -def trace(log_dir): +def trace(log_dir, create_perfetto_link=False): """Context manager to take a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python @@ -103,8 +180,11 @@ def trace(log_dir): Args: log_dir: The directory to save the profiler trace to (usually the TensorBoard log directory). + create_perfetto_link: A boolean which, if true, creates and prints link to + the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will + block until the link is opened and Perfetto loads the trace. """ - start_trace(log_dir) + start_trace(log_dir, create_perfetto_link) try: yield finally: