Add collect_profile script

This commit is contained in:
Sharad Vikram 2022-06-02 22:15:53 -07:00
parent 401812196f
commit 143ed40a78
5 changed files with 169 additions and 1 deletions

View File

@ -147,4 +147,4 @@ jobs:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
run: |
pytest -n 1 --tb=short docs
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization --ignore=jax/collect_profile.py

View File

@ -40,6 +40,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
it.
* Added {func}`jax.random.generalized_normal`.
* Added {func}`jax.random.ball`.
* Added a `python -m jax.collect_profile` script to manually capture program
traces as an alternative to the TensorBoard UI.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -41,6 +41,29 @@ or if you're using Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
```
### Manual capture
Instead of capturing traces programmatically using `jax.profiler.trace`, you can
start a profiling server in the script of interest by calling
`jax.profiler.start_server(<port>)`. If you only need the profiler server to be
active for a portion of your script, you can shut it down by calling
`jax.profiler.stop_server()`.
Once the script is running and after the profiler server has started, we can
manually capture a trace by running:
```bash
$ python -m jax.collect_profile <port> <duration_in_ms>
```
By default, the resulting trace information is dumped into a temporary directory
but this can be overridden by passing in `--log_dir=<directory of choice>`.
Also, by default, the program will prompt you to open a link to
`ui.perfetto.dev`. When you open the link, the Perfetto UI will load the trace
file and open a visualizer. You can disable this feature by passing
`--no_perfetto_link` to the command. Alternatively, you can also point
TensorBoard to the `log_dir` to analyze the trace (see the
"TensorBoard profiling" section below).
## TensorBoard profiling
[TensorBoard's

107
jax/collect_profile.py Normal file
View File

@ -0,0 +1,107 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gzip
import os
import pathlib
import tempfile
from typing import Optional
# pytype: disable=import-error
import jax
try:
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.profiler import profiler_client
except ImportError:
raise ImportError("This script requires `tensorflow` to be installed.")
try:
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert
except ImportError:
raise ImportError(
"This script requires `tensorboard_plugin_profile` to be installed.")
# pytype: enable=import-error
# Text shown at the top of `--help`: outlines the two-step workflow (start the
# profiler server inside the target program, then point this script at it).
_DESCRIPTION = """
To profile running JAX programs, you first need to start the profiler server
in the program of interest. You can do this via
`jax.profiler.start_server(<port>)`. Once the program is running and the
profiler server has started, you can run `collect_profile` to trace the execution
for a provided duration. The trace file will be dumped into a directory
(determined by `--log_dir`) and by default, a Perfetto UI link will be generated
to view the resulting trace.
"""

# Command-line interface. Arguments are registered in a fixed order because
# that order determines both the positional order (`port` before
# `duration_in_ms`) and the layout of the generated help text.
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument(
    "--log_dir",
    type=str,
    default=None,
    help=("Directory to store log files. Uses a temporary directory "
          "if none provided."))
parser.add_argument(
    "port",
    type=int,
    help="Port to collect trace")
parser.add_argument(
    "duration_in_ms",
    type=int,
    help="Duration to collect trace in milliseconds")
parser.add_argument(
    "--no_perfetto_link",
    action="store_true",
    help="Disable creating a perfetto link")
parser.add_argument(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host to collect trace. Defaults to 127.0.0.1")
# Tracer verbosity levels; the values are handed to the profiler unchanged.
parser.add_argument(
    "--host_tracer_level",
    type=int,
    default=2,
    help="Profiler host tracer level")
parser.add_argument(
    "--device_tracer_level",
    type=int,
    default=1,
    help="Profiler device tracer level")
parser.add_argument(
    "--python_tracer_level",
    type=int,
    default=1,
    help="Profiler Python tracer level")
def collect_profile(port: int, duration_in_ms: int, host: str,
                    log_dir: Optional[str], host_tracer_level: int,
                    device_tracer_level: int, python_tracer_level: int,
                    no_perfetto_link: bool):
  """Traces a running profiler server and writes a gzipped JSON trace.

  The function asks the profiler server at `host:port` for a trace, prints the
  directory the profiler wrote to, converts the newest run's `*.xplane.pb` file
  with `tensorboard_plugin_profile` and stores the result next to it as
  `remote.trace.json.gz`.

  Args:
    port: Port of the profiler server to connect to.
    duration_in_ms: Trace duration passed to the profiler client.
    host: Host of the profiler server to connect to.
    log_dir: Directory to write the trace into. When `None`, a fresh temporary
      directory is created with `tempfile.mkdtemp()`.
    host_tracer_level: Forwarded to `ProfilerOptions`.
    device_tracer_level: Forwarded to `ProfilerOptions`.
    python_tracer_level: Forwarded to `ProfilerOptions`.
    no_perfetto_link: When true, the final call that hosts the trace for the
      Perfetto UI is skipped.

  Raises:
    FileNotFoundError: If no run directory exists under
      `<log_dir>/plugins/profile`, or the newest run directory contains no
      `*.xplane.pb` file.
  """
  options = profiler.ProfilerOptions(
      host_tracer_level=host_tracer_level,
      device_tracer_level=device_tracer_level,
      python_tracer_level=python_tracer_level,
  )
  log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp())
  profiler_client.trace(
      f"{host}:{port}",
      str(log_dir_),
      duration_in_ms,
      options=options)
  print(f"Dumped profiling information in: {log_dir_}")

  # The profiler dumps `xplane.pb` to the logging directory. To upload it to
  # the Perfetto trace viewer, we need to convert it to a `trace.json` file.
  # We do this by first finding the `xplane.pb` file, then passing it into
  # tensorboard_plugin_profile's `xplane` conversion function.
  root_trace_folder = log_dir_.resolve() / "plugins" / "profile"
  # `iterdir()` already yields full paths. Only directories can be runs, so
  # stray files are ignored when looking for the most recent one.
  trace_folders = ([entry for entry in root_trace_folder.iterdir()
                    if entry.is_dir()]
                   if root_trace_folder.is_dir() else [])
  if not trace_folders:
    raise FileNotFoundError(
        f"No profile run directories found under: {root_trace_folder}")
  latest_folder = max(trace_folders, key=os.path.getmtime)
  # A default for `next` avoids leaking a bare StopIteration to the caller.
  xplane = next(latest_folder.glob("*.xplane.pb"), None)
  if xplane is None:
    raise FileNotFoundError(
        f"No `*.xplane.pb` file found in: {latest_folder}")
  # NOTE(review): assumes `xspace_to_tool_data` returns a `str` here, as the
  # `.encode` below requires -- confirm against the installed plugin version.
  result = convert.xspace_to_tool_data([xplane], "trace_viewer^", None)

  with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp:
    fp.write(result.encode("utf-8"))

  if not no_perfetto_link:
    jax._src.profiler._host_perfetto_trace_file(str(log_dir_))
def main(args):
  """Runs `collect_profile` with the fields of a parsed argument namespace.

  Args:
    args: Namespace providing `port`, `duration_in_ms`, `host`, `log_dir`,
      `host_tracer_level`, `device_tracer_level`, `python_tracer_level` and
      `no_perfetto_link` attributes (as produced by `parser.parse_args()`).
  """
  collect_profile(
      port=args.port,
      duration_in_ms=args.duration_in_ms,
      host=args.host,
      log_dir=args.log_dir,
      host_tracer_level=args.host_tracer_level,
      device_tracer_level=args.device_tracer_level,
      python_tracer_level=args.python_tracer_level,
      no_perfetto_link=args.no_perfetto_link,
  )


# Script entry point: parse the command line only when executed directly
# (e.g. via `python -m jax.collect_profile`), not when imported.
if __name__ == "__main__":
  main(parser.parse_args())

View File

@ -18,6 +18,7 @@ import os
import shutil
import tempfile
import threading
import time
import unittest
from absl.testing import absltest
@ -39,6 +40,14 @@ except ImportError:
profiler_client = None
tf_profiler = None
TBP_ENABLED = False
try:
import tensorboard_plugin_profile
del tensorboard_plugin_profile
TBP_ENABLED = True
except ImportError:
pass
config.parse_flags_with_absl()
@ -205,5 +214,32 @@ class ProfilerTest(unittest.TestCase):
thread_worker.join(120)
self._check_xspace_pb_exist(logdir)
@unittest.skipIf(
    not (portpicker and profiler_client and tf_profiler and TBP_ENABLED),
    "Test requires tensorflow.profiler, portpicker and "
    "tensorboard_plugin_profile")
def test_remote_profiler(self):
  """Runs `jax.collect_profile` in a subprocess against a live server.

  The profiler server is started before the collector is launched so the
  collector never races the server start-up. The JAX workload keeps running
  until the collector exits, bounded by a timeout, so there is activity to
  trace however long the subprocess takes to start.
  """
  # Function-scope imports keep this test self-contained.
  import subprocess
  import sys

  port = portpicker.pick_unused_port()
  logdir = absltest.get_default_test_tmpdir()
  # Remove any existing log files.
  shutil.rmtree(logdir, ignore_errors=True)

  jax.profiler.start_server(port)

  returncodes = []

  def on_profile():
    # Use the current interpreter and an argument list (no shell) so the
    # subprocess sees the same environment and paths need no quoting.
    returncodes.append(subprocess.call([
        sys.executable, "-m", "jax.collect_profile", str(port), "500",
        "--log_dir", str(logdir), "--no_perfetto_link"]))

  thread_profiler = threading.Thread(
      target=on_profile, args=())
  thread_profiler.start()
  try:
    start_time = time.time()
    y = jnp.zeros((5, 5))
    # Keep the device busy while the collector is running; the deadline only
    # guards against a collector that never terminates.
    while thread_profiler.is_alive() and time.time() - start_time < 120:
      y = jnp.dot(y, y)
  finally:
    jax.profiler.stop_server()
    thread_profiler.join(120)
  self.assertEqual(returncodes, [0])
  self._check_xspace_pb_exist(logdir)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())