mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add collect_profile script
This commit is contained in:
parent
401812196f
commit
143ed40a78
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -147,4 +147,4 @@ jobs:
|
||||
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
|
||||
run: |
|
||||
pytest -n 1 --tb=short docs
|
||||
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization
|
||||
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization --ignore=jax/collect_profile.py
|
||||
|
@ -40,6 +40,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
it.
|
||||
* Added {func}`jax.random.generalized_normal`.
|
||||
* Added {func}`jax.random.ball`.
|
||||
* Added a `python -m jax.collect_profile` script to manually capture program
|
||||
traces as an alternative to the TensorBoard UI.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
@ -41,6 +41,29 @@ or if you're using Google Cloud:
|
||||
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
|
||||
```
|
||||
|
||||
### Manual capture
|
||||
|
||||
Instead of capturing traces programmatically using `jax.profiler.trace`, you can
|
||||
start a profiling server in the script of interest by calling
|
||||
`jax.profiler.start_server(<port>)`. If you only need the profiler server to be
|
||||
active for a portion of your script, you can shut it down by calling
|
||||
`jax.profiler.stop_server()`.
|
||||
|
||||
Once the script is running and after the profiler server has started, we can
|
||||
manually capture a trace by running:
|
||||
```bash
|
||||
$ python -m jax.collect_profile <port> <duration_in_ms>
|
||||
```
|
||||
|
||||
By default, the resulting trace information is dumped into a temporary directory
|
||||
but this can be overridden by passing in `--log_dir=<directory of choice>`.
|
||||
Also, by default, the program will prompt you to open a link to
|
||||
`ui.perfetto.dev`. When you open the link, the Perfetto UI will load the trace
|
||||
file and open a visualizer. This feature is disabled by passing in
|
||||
`--no_perfetto_link` into the command. Alternatively, you can also point
|
||||
TensorBoard to the `log_dir` to analyze the trace (see the
|
||||
"TensorBoard profiling" section below).
|
||||
|
||||
## TensorBoard profiling
|
||||
|
||||
[TensorBoard's
|
||||
|
107
jax/collect_profile.py
Normal file
107
jax/collect_profile.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gzip
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# pytype: disable=import-error
|
||||
import jax
|
||||
try:
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
except ImportError:
|
||||
raise ImportError("This script requires `tensorflow` to be installed.")
|
||||
try:
|
||||
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"This script requires `tensorboard_plugin_profile` to be installed.")
|
||||
# pytype: enable=import-error
|
||||
|
||||
|
||||
_DESCRIPTION = """
To profile running JAX programs, you first need to start the profiler server
in the program of interest. You can do this via
`jax.profiler.start_server(<port>)`. Once the program is running and the
profiler server has started, you can run `collect_profile` to trace the execution
for a provided duration. The trace file will be dumped into a directory
(determined by `--log_dir`) and by default, a Perfetto UI link will be generated
to view the resulting trace.
"""


def _build_parser():
  """Builds the command-line parser for the `collect_profile` script.

  Positional arguments are `port` and `duration_in_ms` (both ints). All other
  arguments are optional flags with the defaults listed below.
  """
  cli = argparse.ArgumentParser(description=_DESCRIPTION)

  # Where the captured trace ends up; `None` means "use a temporary directory".
  cli.add_argument(
      "--log_dir",
      type=str,
      default=None,
      help=("Directory to store log files. "
            "Uses a temporary directory if none provided."))

  # Required positionals, in the order they appear on the command line.
  cli.add_argument("port", type=int, help="Port to collect trace")
  cli.add_argument(
      "duration_in_ms",
      type=int,
      help="Duration to collect trace in milliseconds")

  cli.add_argument(
      "--no_perfetto_link",
      action="store_true",
      help="Disable creating a perfetto link")
  cli.add_argument(
      "--host",
      type=str,
      default="127.0.0.1",
      help="Host to collect trace. Defaults to 127.0.0.1")

  # The three tracer-level flags only differ in name, default and help text.
  tracer_flags = (
      ("--host_tracer_level", 2, "Profiler host tracer level"),
      ("--device_tracer_level", 1, "Profiler device tracer level"),
      ("--python_tracer_level", 1, "Profiler Python tracer level"),
  )
  for flag, default_level, help_text in tracer_flags:
    cli.add_argument(flag, type=int, default=default_level, help=help_text)

  return cli


parser = _build_parser()
|
||||
|
||||
def collect_profile(port: int, duration_in_ms: int, host: str,
                    log_dir: Optional[str], host_tracer_level: int,
                    device_tracer_level: int, python_tracer_level: int,
                    no_perfetto_link: bool):
  """Captures a trace from a running profiler server and stores it on disk.

  Requests a trace from the profiler server at `host:port`, then converts the
  newest dumped `*.xplane.pb` file into a gzipped `remote.trace.json.gz` placed
  next to it.

  Args:
    port: Port of the profiler server to connect to.
    duration_in_ms: Duration of the trace, in milliseconds.
    host: Host of the profiler server to connect to.
    log_dir: Directory to dump the trace into. When `None`, a new temporary
      directory is created with `tempfile.mkdtemp()`.
    host_tracer_level: Forwarded to `profiler.ProfilerOptions`.
    device_tracer_level: Forwarded to `profiler.ProfilerOptions`.
    python_tracer_level: Forwarded to `profiler.ProfilerOptions`.
    no_perfetto_link: If true, the Perfetto hosting step at the end is skipped.

  Raises:
    FileNotFoundError: If, after tracing, the log directory contains no
      `plugins/profile/<run>` folder or that folder has no `*.xplane.pb` file.
  """
  options = profiler.ProfilerOptions(
      host_tracer_level=host_tracer_level,
      device_tracer_level=device_tracer_level,
      python_tracer_level=python_tracer_level,
  )
  log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp())
  profiler_client.trace(
      f"{host}:{port}",
      str(log_dir_),
      duration_in_ms,
      options=options)
  print(f"Dumped profiling information in: {log_dir_}")
  # The profiler dumps `xplane.pb` to the logging directory. To upload it to
  # the Perfetto trace viewer, we need to convert it to a `trace.json` file.
  # We do this by first finding the `xplane.pb` file, then passing it into
  # tensorboard_plugin_profile's `xplane` conversion function.
  root_trace_folder = log_dir_.resolve() / "plugins" / "profile"
  if not root_trace_folder.is_dir():
    raise FileNotFoundError(
        f"No profile output found: {root_trace_folder} does not exist.")
  # `iterdir()` already yields full paths. Only directories can hold a trace,
  # so stray files are ignored when looking for the most recent run.
  trace_folders = [entry for entry in root_trace_folder.iterdir()
                   if entry.is_dir()]
  if not trace_folders:
    raise FileNotFoundError(
        f"No trace folders found in {root_trace_folder}.")
  latest_folder = max(trace_folders, key=os.path.getmtime)
  # Use a default so a missing file produces a clear error instead of a bare
  # `StopIteration`.
  xplane = next(latest_folder.glob("*.xplane.pb"), None)
  if xplane is None:
    raise FileNotFoundError(
        f"No `*.xplane.pb` file found in {latest_folder}.")
  result = convert.xspace_to_tool_data([xplane], "trace_viewer^", None)

  with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp:
    fp.write(result.encode("utf-8"))

  if not no_perfetto_link:
    # Private JAX helper; presumably serves the converted trace and prints the
    # Perfetto UI link -- confirm against `jax._src.profiler`.
    jax._src.profiler._host_perfetto_trace_file(str(log_dir_))
|
||||
|
||||
def main(args):
  """Runs `collect_profile` with the fields of the parsed arguments `args`."""
  collect_profile(
      port=args.port,
      duration_in_ms=args.duration_in_ms,
      host=args.host,
      log_dir=args.log_dir,
      host_tracer_level=args.host_tracer_level,
      device_tracer_level=args.device_tracer_level,
      python_tracer_level=args.python_tracer_level,
      no_perfetto_link=args.no_perfetto_link,
  )


# Script entry point: parse the command line and capture a profile.
if __name__ == "__main__":
  main(parser.parse_args())
|
@ -18,6 +18,7 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
@ -39,6 +40,14 @@ except ImportError:
|
||||
profiler_client = None
|
||||
tf_profiler = None
|
||||
|
||||
TBP_ENABLED = False
|
||||
try:
|
||||
import tensorboard_plugin_profile
|
||||
del tensorboard_plugin_profile
|
||||
TBP_ENABLED = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@ -205,5 +214,32 @@ class ProfilerTest(unittest.TestCase):
|
||||
thread_worker.join(120)
|
||||
self._check_xspace_pb_exist(logdir)
|
||||
|
||||
@unittest.skipIf(
    not (portpicker and profiler_client and tf_profiler and TBP_ENABLED),
    "Test requires tensorflow.profiler, portpicker and "
    "tensorboard_plugin_profile")
def test_remote_profiler(self):
  """Checks that `python -m jax.collect_profile` captures a remote trace.

  The collector script runs in a background thread as a subprocess while this
  process starts a profiler server and keeps doing work for three seconds.
  Afterwards the log directory must contain the dumped xspace file.
  """
  # Imported locally so this test does not depend on the module's import block.
  import subprocess
  import sys

  port = portpicker.pick_unused_port()

  logdir = absltest.get_default_test_tmpdir()
  # Remove any existing log files.
  shutil.rmtree(logdir, ignore_errors=True)

  def on_profile():
    # Run the collector with the interpreter executing this test rather than
    # whatever `python` is on PATH. An argument list avoids shell quoting
    # problems with the log directory.
    subprocess.run(
        [sys.executable, "-m", "jax.collect_profile", str(port), "500",
         "--log_dir", logdir, "--no_perfetto_link"],
        check=False)

  thread_profiler = threading.Thread(target=on_profile)
  thread_profiler.start()
  jax.profiler.start_server(port)
  try:
    # Keep the process busy long enough for the collector to connect and
    # record its 500 ms trace.
    start_time = time.time()
    y = jnp.zeros((5, 5))
    while time.time() - start_time < 3:
      y = jnp.dot(y, y)
  finally:
    # Always shut the server down and reap the collector, even if the
    # workload above fails, so later tests start from a clean state.
    jax.profiler.stop_server()
    thread_profiler.join()
  self._check_xspace_pb_exist(logdir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user