mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add programmatic profiling APIs, and rename some existing APIs.
This change provides aliases for the renamed APIs so existing code won't break. We should remove these aliases after the next release.
This commit is contained in:
parent
b8812b2a5d
commit
b68a08adf1
10
CHANGELOG.md
10
CHANGELOG.md
@ -10,9 +10,17 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jax 0.2.12 (unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master).
|
||||
|
||||
* New features
|
||||
* New profiling APIs: {func}`jax.profiler.start_trace`,
|
||||
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
|
||||
* Breaking changes:
|
||||
* The minimum jaxlib version is now 0.1.64.
|
||||
* Some profiler APIs names have been changed. There are still aliases, so this
|
||||
should not break existing code, but the aliases will eventually be removed
|
||||
so please change your code.
|
||||
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
|
||||
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
|
||||
* `trace_function` --> {func}`~jax.profiler.annotate_function`
|
||||
|
||||
|
||||
## jaxlib 0.1.65 (unreleased)
|
||||
|
@ -15,8 +15,12 @@ features.
|
||||
:toctree: _autosummary
|
||||
|
||||
start_server
|
||||
trace_function
|
||||
TraceContext
|
||||
start_trace
|
||||
stop_trace
|
||||
trace
|
||||
annotate_function
|
||||
TraceAnnotation
|
||||
StepTraceAnnotation
|
||||
|
||||
|
||||
Device memory profiling
|
||||
@ -30,4 +34,3 @@ profiling features.
|
||||
|
||||
device_memory_profile
|
||||
save_device_memory_profile
|
||||
|
@ -18,7 +18,69 @@ Install specific nightly versions of TensorBoard, TensorBoard profiler, TensorFl
|
||||
pip install --upgrade tb-nightly==2.5.0a20201203 tbp-nightly==2.4.0a20201203 tf-nightly==2.5.0.dev20201203 tensorboard-plugin-wit==1.7.0
|
||||
```
|
||||
|
||||
### Usage
|
||||
### Programmatic capture
|
||||
|
||||
You can instrument your code to capture a profiler trace via the
|
||||
{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace`
|
||||
methods. Call {func}`~jax.profiler.start_trace` with the directory to write
|
||||
trace files to. This should be the same `--logdir` directory used to start
|
||||
TensorBoard. Then, you can use TensorBoard to view the traces.
|
||||
|
||||
For example, to take a profiler trace:
|
||||
|
||||
```python
|
||||
import jax
|
||||
|
||||
jax.profiler.start_trace("/tmp/tensorboard")
|
||||
|
||||
# Run the operations to be profiled
|
||||
key = jax.random.PRNGKey(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
|
||||
jax.profiler.stop_trace()
|
||||
```
|
||||
|
||||
Note the {func}`block_until_ready` call. We use this to make sure on-device
|
||||
execution is captured by the trace. See {ref}`async-dispatch` for details on why
|
||||
this is necessary.
|
||||
|
||||
You can also use the {func}`jax.profiler.trace` context manager as an
|
||||
alternative to `start_trace` and `stop_trace`:
|
||||
|
||||
```python
|
||||
import jax
|
||||
|
||||
with jax.profiler.trace():
|
||||
key = jax.random.PRNGKey(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
```
|
||||
|
||||
To view the trace, first start TensorBoard if you haven't already:
|
||||
|
||||
```shell
|
||||
$ tensorboard --logdir /tmp/tensorboard
|
||||
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
|
||||
TensorBoard 2.5.0a20210304 at http://localhost:6006/ (Press CTRL+C to quit)
|
||||
```
|
||||
|
||||
You should be able to load TensorBoard at <http://localhost:6006/> in this
|
||||
example. You can specify a different port with the `--port` flag. See
|
||||
{ref}`remote_profiling` below if running JAX on a remote server.
|
||||
|
||||
Then, either select "Profile" in the upper-right dropdown menu, or go directly
|
||||
to <http://localhost:6006/#profile>. Available traces appear in the "Runs"
|
||||
dropdown menu on the left. Select the run you're interested in, and then under
|
||||
"Tools", select "trace_viewer". You should now see a timeline of the
|
||||
execution. You can use the WASD keys to navigate the trace, and click or drag to
|
||||
select events to see more details at the bottom. See [these TensorFlow
|
||||
docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
|
||||
for more details on using the trace viewer.
|
||||
|
||||
### Manual capture via TensorBoard
|
||||
|
||||
The following are instructions for capturing a manually-triggered N-second trace
|
||||
from a running program.
|
||||
@ -76,10 +138,12 @@ from a running program.
|
||||
docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
|
||||
for more details on using the trace viewer.<br /><br />
|
||||
|
||||
1. By default, the events in the trace viewer are mostly low-level internal JAX
|
||||
functions. You can add your own events and functions by using
|
||||
{func}`jax.profiler.TraceContext` and {func}`jax.profiler.trace_function` in
|
||||
your code and capturing a new profile.
|
||||
### Adding custom trace events
|
||||
|
||||
By default, the events in the trace viewer are mostly low-level internal JAX
|
||||
functions. You can add your own events and functions by using
|
||||
{class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in
|
||||
your code.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
|
@ -12,8 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
import warnings
|
||||
|
||||
from jax.lib import xla_bridge
|
||||
from jax.lib import xla_client
|
||||
@ -24,8 +27,8 @@ def start_server(port: int):
|
||||
|
||||
Using the "TensorFlow profiler" feature in `TensorBoard
|
||||
<https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can
|
||||
connect to the profiler server and sample execution traces that show CPU and
|
||||
GPU device activity.
|
||||
connect to the profiler server and sample execution traces that show CPU,
|
||||
GPU, and/or TPU device activity.
|
||||
|
||||
Returns a profiler server object. The server remains alive and listening until
|
||||
the server object is destroyed.
|
||||
@ -33,8 +36,80 @@ def start_server(port: int):
|
||||
return xla_client.profiler.start_server(port)
|
||||
|
||||
|
||||
class TraceContext(xla_client.profiler.TraceMe):
|
||||
"""Context manager generates a trace event in the profiler.
|
||||
class _ProfileState(object):
|
||||
def __init__(self):
|
||||
self.profile_session = None
|
||||
self.log_dir = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
_profile_state = _ProfileState()
|
||||
|
||||
|
||||
def start_trace(log_dir):
|
||||
"""Starts a profiler trace.
|
||||
|
||||
The trace will capture CPU, GPU, and/or TPU activity, including Python
|
||||
functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
|
||||
and save the results to ``log_dir``.
|
||||
|
||||
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
|
||||
doesn't need to be running when collecting the trace.
|
||||
|
||||
Only once trace may be collected a time. A RuntimeError will be raised if
|
||||
``start_trace()`` is called while another trace is running.
|
||||
|
||||
Args:
|
||||
log_dir: The directory to save the profiler trace to (usually the
|
||||
TensorBoard log directory).
|
||||
"""
|
||||
with _profile_state.lock:
|
||||
if _profile_state.profile_session is not None:
|
||||
raise RuntimeError("Profile has already been started. "
|
||||
"Only one profile may be run at a time.")
|
||||
_profile_state.profile_session = xla_client.profiler.ProfilerSession()
|
||||
_profile_state.log_dir = log_dir
|
||||
|
||||
|
||||
def stop_trace():
|
||||
"""Stops the currently-running profiler trace.
|
||||
|
||||
The trace will be saved to the ``log_dir`` passed to the corresponding
|
||||
``start_trace()`` call. Raises a RuntimeError if a trace hasn't been started.
|
||||
"""
|
||||
with _profile_state.lock:
|
||||
if _profile_state.profile_session is None:
|
||||
raise RuntimeError("No profile started")
|
||||
_profile_state.profile_session.stop_and_export(_profile_state.log_dir)
|
||||
_profile_state.profile_session = None
|
||||
_profile_state.log_dir = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace(log_dir):
|
||||
"""Context manager to take a profiler trace.
|
||||
|
||||
The trace will capture CPU, GPU, and/or TPU activity, including Python
|
||||
functions and JAX on-device operations.
|
||||
|
||||
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
|
||||
doesn't need to be running when collecting the trace.
|
||||
|
||||
Only once trace may be collected a time. A RuntimeError will be raised if a
|
||||
trace is started while another trace is running.
|
||||
|
||||
Args:
|
||||
log_dir: The directory to save the profiler trace to (usually the
|
||||
TensorBoard log directory).
|
||||
"""
|
||||
start_trace(log_dir)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
stop_trace()
|
||||
|
||||
|
||||
class TraceAnnotation(xla_client.profiler.TraceMe):
|
||||
"""Context manager that generates a trace event in the profiler.
|
||||
|
||||
The trace event spans the duration of the code enclosed by the context.
|
||||
|
||||
@ -42,16 +117,25 @@ class TraceContext(xla_client.profiler.TraceMe):
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> x = jnp.ones((1000, 1000))
|
||||
>>> with jax.profiler.TraceContext("acontext"):
|
||||
>>> with jax.profiler.TraceAnnotation("my_label"):
|
||||
... jnp.dot(x, x.T).block_until_ready()
|
||||
|
||||
This will cause an "acontext" event to show up on the trace timeline if the
|
||||
event occurs while the process is being traced by TensorBoard.
|
||||
This will cause a "my_label" event to show up on the trace timeline if the
|
||||
event occurs while the process is being traced.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StepTraceContext(TraceContext):
|
||||
# TODO: remove this sometime after jax 0.1.11 is released
|
||||
class TraceContext(TraceAnnotation):
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
"TraceContext has been renamed to TraceAnnotation. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class StepTraceAnnotation(TraceAnnotation):
|
||||
"""Context manager that generates a step trace event in the profiler.
|
||||
|
||||
The step trace event spans the duration of the code enclosed by the context.
|
||||
@ -63,7 +147,7 @@ class StepTraceContext(TraceContext):
|
||||
>>> import jax
|
||||
>>>
|
||||
>>> while global_step < NUM_STEPS:
|
||||
... with jax.profiler.StepTraceContext("train", step_num=global_step):
|
||||
... with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
|
||||
... train_step()
|
||||
... global_step += 1
|
||||
|
||||
@ -79,14 +163,23 @@ class StepTraceContext(TraceContext):
|
||||
super().__init__(name, _r=1, **kwargs)
|
||||
|
||||
|
||||
def trace_function(func: Callable, name: str = None, **kwargs):
|
||||
# TODO: remove this sometime after jax 0.1.11 is released
|
||||
class StepTraceContext(StepTraceAnnotation):
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
"StepTraceContext has been renamed to StepTraceAnnotation. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def annotate_function(func: Callable, name: str = None, **kwargs):
|
||||
"""Decorator that generates a trace event for the execution of a function.
|
||||
|
||||
For example:
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>>
|
||||
>>> @jax.profiler.trace_function
|
||||
>>> @jax.profiler.annotate_function
|
||||
>>> def f(x):
|
||||
... return jnp.dot(x, x.T).block_until_ready()
|
||||
>>>
|
||||
@ -111,12 +204,21 @@ def trace_function(func: Callable, name: str = None, **kwargs):
|
||||
name = name or func.__name__
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with TraceContext(name, **kwargs):
|
||||
with TraceAnnotation(name, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
return wrapper
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.1.11 is released
|
||||
def trace_function(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"trace_function has been renamed to annotate_function. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
return annotate_function(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
def device_memory_profile(backend: Optional[str] = None) -> bytes:
|
||||
"""Captures a JAX device memory profile as ``pprof``-format protocol buffer.
|
||||
|
||||
|
@ -14,10 +14,16 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.profiler import (
|
||||
StepTraceAnnotation,
|
||||
StepTraceContext,
|
||||
TraceAnnotation,
|
||||
TraceContext,
|
||||
device_memory_profile,
|
||||
save_device_memory_profile,
|
||||
start_server,
|
||||
start_trace,
|
||||
stop_trace,
|
||||
trace,
|
||||
annotate_function,
|
||||
trace_function,
|
||||
)
|
||||
|
@ -16,6 +16,7 @@ from functools import partial
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
@ -56,23 +57,69 @@ class ProfilerTest(unittest.TestCase):
|
||||
jax.profiler.start_server(port=port)
|
||||
del port
|
||||
|
||||
def testTraceContext(self):
|
||||
def testProgrammaticProfiling(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
jax.profiler.start_trace(tmpdir)
|
||||
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
|
||||
jnp.ones(jax.local_device_count()))
|
||||
jax.profiler.stop_trace()
|
||||
|
||||
proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
|
||||
recursive=True)
|
||||
self.assertEqual(len(proto_path), 1)
|
||||
with open(proto_path[0], "rb") as f:
|
||||
proto = f.read()
|
||||
# Sanity check that serialized proto contains host and device traces
|
||||
# without deserializing.
|
||||
self.assertIn(b"/host:CPU", proto)
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.assertIn(b"/device:TPU", proto)
|
||||
|
||||
def testProgrammaticProfilingErrors(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "No profile started"):
|
||||
jax.profiler.stop_trace()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
jax.profiler.start_trace(tmpdir)
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Profile has already been started. Only one "
|
||||
"profile may be run at a time."):
|
||||
jax.profiler.start_trace(tmpdir)
|
||||
|
||||
def testProgrammaticProfilingContextManager(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with jax.profiler.trace(tmpdir):
|
||||
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
|
||||
jnp.ones(jax.local_device_count()))
|
||||
|
||||
proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
|
||||
recursive=True)
|
||||
self.assertEqual(len(proto_path), 1)
|
||||
with open(proto_path[0], "rb") as f:
|
||||
proto = f.read()
|
||||
# Sanity check that serialized proto contains host and device traces
|
||||
# without deserializing.
|
||||
self.assertIn(b"/host:CPU", proto)
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.assertIn(b"/device:TPU", proto)
|
||||
|
||||
def testTraceAnnotation(self):
|
||||
x = 3
|
||||
with jax.profiler.TraceContext("mycontext"):
|
||||
with jax.profiler.TraceAnnotation("mycontext"):
|
||||
x = x + 2
|
||||
|
||||
def testTraceFunction(self):
|
||||
@jax.profiler.trace_function
|
||||
@jax.profiler.annotate_function
|
||||
def f(x):
|
||||
return x + 2
|
||||
self.assertEqual(f(7), 9)
|
||||
|
||||
@partial(jax.profiler.trace_function, name="aname")
|
||||
@partial(jax.profiler.annotate_function, name="aname")
|
||||
def g(x):
|
||||
return x + 2
|
||||
self.assertEqual(g(7), 9)
|
||||
|
||||
@partial(jax.profiler.trace_function, name="aname", akwarg="hello")
|
||||
@partial(jax.profiler.annotate_function, name="aname", akwarg="hello")
|
||||
def h(x):
|
||||
return x + 2
|
||||
self.assertEqual(h(7), 9)
|
||||
@ -96,7 +143,7 @@ class ProfilerTest(unittest.TestCase):
|
||||
worker_start.set()
|
||||
x = jnp.ones((1000, 1000))
|
||||
while True:
|
||||
with jax.profiler.TraceContext("atracecontext"):
|
||||
with jax.profiler.TraceAnnotation("atraceannotation"):
|
||||
jnp.dot(x, x.T).block_until_ready()
|
||||
if self.profile_done:
|
||||
break
|
||||
|
Loading…
x
Reference in New Issue
Block a user