Peter Hawkins
|
dbf03d2ee6
|
[JAX] Add support for generating "equation profiles" in JAX.
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result.
For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
primitive: Total 6062.0
1509.0 (24.89%): mul
936.0 (15.44%): add
589.0 ( 9.72%): reshape
492.0 ( 8.12%): div
485.0 ( 8.00%): broadcast_in_dim
330.0 ( 5.44%): reduce_sum
322.0 ( 5.31%): integer_pow
230.0 ( 3.79%): add_any
174.0 ( 2.87%): convert_element_type
160.0 ( 2.64%): select
158.0 ( 2.61%): conv_general_dilated
116.0 ( 1.91%): sub
110.0 ( 1.81%): eq
110.0 ( 1.81%): neg
104.0 ( 1.72%): max
53.0 ( 0.87%): rsqrt
52.0 ( 0.86%): rev
49.0 ( 0.81%): custom_jvp_call_jaxpr
49.0 ( 0.81%): gt
5.0 (0.082%): xla_call
4.0 (0.066%): min
3.0 (0.049%): dot_general
3.0 (0.049%): lt
2.0 (0.033%): cos
2.0 (0.033%): exp
2.0 (0.033%): iota
2.0 (0.033%): log
2.0 (0.033%): psum
2.0 (0.033%): reduce_max
2.0 (0.033%): stop_gradient
1.0 (0.016%): argmax
1.0 (0.016%): reduce_window_max
1.0 (0.016%): select_and_scatter_add
1.0 (0.016%): transpose
1.0 (0.016%): xla_pmap
```
Or the lines of code that generated the most equations:
```
$ pprof --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
flat flat% sum% cum cum%
1537 25.35% 25.35% 1537 25.35% _compute_stats
1484 24.48% 49.84% 1484 24.48% _normalize
849 14.01% 63.84% 6062 100% __call__
644 10.62% 74.46% 644 10.62% <unknown>
483 7.97% 82.43% 483 7.97% <unknown>
392 6.47% 88.90% 6061 100% train_step
324 5.34% 94.24% 324 5.34% <unknown>
161 2.66% 96.90% 161 2.66% <unknown>
57 0.94% 97.84% 4292 70.80% loss_fn
52 0.86% 98.70% 52 0.86% schedule
39 0.64% 99.34% 39 0.64% softmax_cross_entropy
8 0.13% 99.47% 30 0.49% compute_metrics
6 0.099% 99.57% 61 1.01% cross_entropy_loss
1 0.016% 99.59% 1321 21.79% apply_gradients
1 0.016% 99.60% 6062 100% train_and_evaluate
0 0% 99.60% 6062 100% <unknown>
0 0% 99.60% 6062 100% __init__
0 0% 99.60% 3872 63.87% _call_wrapped_method
0 0% 99.60% 6062 100% _run_and_get_tests_result
0 0% 99.60% 6062 100% _run_code_in_main
0 0% 99.60% 6062 100% _run_in_app
0 0% 99.60% 6062 100% _run_main
0 0% 99.60% 3872 63.87% apply
0 0% 99.60% 161 2.66% apply_updates
0 0% 99.60% 6062 100% main
0 0% 99.60% 6062 100% main_function
0 0% 99.60% 6062 100% run
0 0% 99.60% 6062 100% runTests
0 0% 99.60% 6062 100% run_filename_as_main
0 0% 99.60% 6062 100% run_tests
0 0% 99.60% 3872 63.87% scope_fn
0 0% 99.60% 6062 100% test_train_and_evaluate
0 0% 99.60% 1159 19.12% update_fn
0 0% 99.60% 3872 63.87% wrapped_fn
0 0% 99.60% 3872 63.87% wrapped_module_method
0 0% 99.60% 3872 63.87% wrapper
```
I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```
[XLA:Python] Add helpers to Traceback and for working with pprof profiles.
* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.
PiperOrigin-RevId: 421395346
|
2022-01-12 14:27:57 -08:00 |
|