diff --git a/docs/_static/tensorboard_profiler.png b/docs/_static/tensorboard_profiler.png
new file mode 100644
index 000000000..9b276b41d
Binary files /dev/null and b/docs/_static/tensorboard_profiler.png differ
diff --git a/docs/profiling.md b/docs/profiling.md
index d5d0a4e28..637307024 100644
--- a/docs/profiling.md
+++ b/docs/profiling.md
@@ -1,12 +1,141 @@
 # Profiling JAX programs
 
-To profile JAX programs, there are currently two options: `nvprof` and XLA's
-profiling features.
+## TensorBoard profiling
 
-## nvprof
+[TensorBoard's
+profiler](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras)
+can be used to profile JAX programs. TensorBoard is a great way to acquire and
+visualize performance traces and profiles of your program, including activity on
+GPU and TPU. The end result looks something like this:
 
-Nvidia's `nvprof` tool can be used to trace and profile JAX code on GPU. For
-details, see the `nvprof` documentation.
+![TensorBoard profiler example](_static/tensorboard_profiler.png)
+
+### Installation
+
+```shell
+# Requires TensorFlow and TensorBoard version >= 2.2
+pip install --upgrade tensorflow tensorboard_plugin_profile
+```
+
+### Usage
+
+The following are instructions for capturing a manually-triggered N-second trace
+from a running program.
+
+1. Start a TensorBoard server:
+
+   ```shell
+   tensorboard --logdir /tmp/tensorboard/
+   ```
+
+   You should be able to load TensorBoard at <http://localhost:6006/>. You can
+   specify a different port with the `--port` flag. See {ref}`remote_profiling`
+   below if running JAX on a remote server.
+
+1. In the Python program or process you'd like to profile, add the following
+   somewhere near the beginning:
+
+   ```python
+   import jax.profiler
+   jax.profiler.start_server(9999)
+   ```
+
+   This starts the profiler server that TensorBoard connects to. The profiler
+   server must be running before you move on to the next step.
+
+   If you'd like to profile a snippet of a long-running program (e.g. a long
+   training loop), you can put this at the beginning of the program and start
+   your program as usual. If you'd like to profile a short program (e.g. a
+   microbenchmark), one option is to start the profiler server in an IPython
+   shell, and run the short program with `%run` after starting the capture in
+   the next step. Another option is to start the profiler server at the
+   beginning of the program and use `time.sleep()` to give you enough time to
+   start the capture.

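The `time.sleep()` approach above can be sketched as follows. This is only an illustration: the port matches the instructions, but the sleep duration and the matrix-multiply workload are arbitrary choices.

```python
import time

import jax
import jax.numpy as jnp
import jax.profiler

# Start the profiler server so TensorBoard can connect to it;
# port 9999 matches the instructions above.
server = jax.profiler.start_server(9999)

# Give yourself a window in which to click "CAPTURE" in TensorBoard
# before the workload runs (use however long you need).
time.sleep(2)

# The workload to be profiled: a jitted matrix multiply.
x = jnp.ones((1000, 1000))
result = jax.jit(jnp.dot)(x, x)
result.block_until_ready()  # make sure the work finishes inside the capture
```

Run this as a normal script while the capture is in progress; the dot product should then show up in the trace viewer.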
+
+1. Open <http://localhost:6006/#profile>, and click the "CAPTURE PROFILE" button
+   in the upper left. Enter "localhost:9999" as the profile service URL (this is
+   the address of the profiler server you started in the previous step). Enter
+   the number of milliseconds you'd like to profile for, and click "CAPTURE".
+
+1. If the code you'd like to profile isn't already running (e.g. if you started
+   the profiler server in a Python shell), run it while the capture is
+   running.
+
+1. After the capture finishes, TensorBoard should automatically refresh. (Not
+   all of the TensorBoard profiling features are hooked up with JAX, so it may
+   initially look like nothing was captured.) On the left under "Tools", select
+   "trace_viewer".
+
+   You should now see a timeline of the execution. You can use the WASD keys to
+   navigate the trace, and click or drag to select events to see more details at
+   the bottom. See [these TensorFlow
+   docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)
+   for more details on using the trace viewer.

+
+1. By default, the events in the trace viewer are mostly low-level internal JAX
+   functions. You can add your own events and functions by using
+   {func}`jax.profiler.TraceContext` and {func}`jax.profiler.trace_function` in
+   your code and capturing a new profile.
+
+### Troubleshooting
+
+#### GPU profiling
+
+Programs running on GPU should produce traces for the GPU streams near the top
+of the trace viewer. If you're only seeing the host traces, check your program
+logs and/or output for the following error messages.
+
+**If you get an error like: `Could not load dynamic library 'libcupti.so.10.1'`**
+Full error:
+```
+W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
+2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.
+```
+
+Add the path to `libcupti.so` to the environment variable `LD_LIBRARY_PATH`.
+(Try `locate libcupti.so` to find the path.) For example:
+```shell
+export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH
+```
+
+**If you get an error like: `failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES`**
+Full error:
+```
+E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback( 0 , subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
+2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity)failed with error CUPTI_ERROR_NOT_INITIALIZED
+```
+
+Run the following commands (note this requires a reboot):
+```shell
+echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf
+sudo update-initramfs -u
+sudo reboot now
+```
+
+See [Nvidia's documentation on this
+error](https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti)
+for more information.
+
+(remote_profiling)=
+#### Profiling on a remote machine
+
+If the JAX program you'd like to profile is running on a remote machine, one
+option is to run all the instructions above on the remote machine (in
+particular, start the TensorBoard server on the remote machine), then use SSH
+local port forwarding to access the TensorBoard web UI from your local
+machine. Use the following SSH command to forward the default TensorBoard port
+6006 from the local machine to the remote machine:
+
+```shell
+ssh -L 6006:localhost:6006 <remote server address>
+```
+
+## Nsight
+
+Nvidia's `Nsight` tools can be used to trace and profile JAX code on GPU. For
+details, see the [`Nsight`
+documentation](https://developer.nvidia.com/tools-overview).
 
 ## XLA profiling