Roy Frostig
d927a5dbf3
migrate internal dependencies from jax.core
to jax._src.core
...
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
2a422c7203
Fix or ignore some pytype errors.
...
PiperOrigin-RevId: 452208582
2022-05-31 21:25:55 -07:00
Matthew Johnson
9cd55a2bbd
[remove-units] remove units
2022-05-04 10:58:56 -07:00
Peter Hawkins
c978df5550
Increase minimum jaxlib version to 0.3.0.
2022-03-04 10:33:03 -05:00
Peter Hawkins
82d8261308
Speed up source location computation when lowering a jaxpr to HLO/MHLO.
...
Speed up source_info_util.user_frames by using a newly refactored Traceback.raw_frames() attribute. Since we are interested only in one frame, it's best to avoid doing wasted work on all the frames we are going to ignore.
Change traceback.raw_frames() to return the transpose of what it previously returned because it means we only need to build 3 Python objects, rather than n + 1 Python objects for n frames.
PiperOrigin-RevId: 427320674
2022-02-08 16:17:40 -08:00
Peter Hawkins
dbf03d2ee6
[JAX] Add support for generating "equation profiles" in JAX.
...
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result.
For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
primitive: Total 6062.0
1509.0 (24.89%): mul
936.0 (15.44%): add
589.0 ( 9.72%): reshape
492.0 ( 8.12%): div
485.0 ( 8.00%): broadcast_in_dim
330.0 ( 5.44%): reduce_sum
322.0 ( 5.31%): integer_pow
230.0 ( 3.79%): add_any
174.0 ( 2.87%): convert_element_type
160.0 ( 2.64%): select
158.0 ( 2.61%): conv_general_dilated
116.0 ( 1.91%): sub
110.0 ( 1.81%): eq
110.0 ( 1.81%): neg
104.0 ( 1.72%): max
53.0 ( 0.87%): rsqrt
52.0 ( 0.86%): rev
49.0 ( 0.81%): custom_jvp_call_jaxpr
49.0 ( 0.81%): gt
5.0 (0.082%): xla_call
4.0 (0.066%): min
3.0 (0.049%): dot_general
3.0 (0.049%): lt
2.0 (0.033%): cos
2.0 (0.033%): exp
2.0 (0.033%): iota
2.0 (0.033%): log
2.0 (0.033%): psum
2.0 (0.033%): reduce_max
2.0 (0.033%): stop_gradient
1.0 (0.016%): argmax
1.0 (0.016%): reduce_window_max
1.0 (0.016%): select_and_scatter_add
1.0 (0.016%): transpose
1.0 (0.016%): xla_pmap
```
Or the lines of code that generated the most equations:
```
$ pprof --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
flat flat% sum% cum cum%
1537 25.35% 25.35% 1537 25.35% _compute_stats
1484 24.48% 49.84% 1484 24.48% _normalize
849 14.01% 63.84% 6062 100% __call__
644 10.62% 74.46% 644 10.62% <unknown>
483 7.97% 82.43% 483 7.97% <unknown>
392 6.47% 88.90% 6061 100% train_step
324 5.34% 94.24% 324 5.34% <unknown>
161 2.66% 96.90% 161 2.66% <unknown>
57 0.94% 97.84% 4292 70.80% loss_fn
52 0.86% 98.70% 52 0.86% schedule
39 0.64% 99.34% 39 0.64% softmax_cross_entropy
8 0.13% 99.47% 30 0.49% compute_metrics
6 0.099% 99.57% 61 1.01% cross_entropy_loss
1 0.016% 99.59% 1321 21.79% apply_gradients
1 0.016% 99.60% 6062 100% train_and_evaluate
0 0% 99.60% 6062 100% <unknown>
0 0% 99.60% 6062 100% __init__
0 0% 99.60% 3872 63.87% _call_wrapped_method
0 0% 99.60% 6062 100% _run_and_get_tests_result
0 0% 99.60% 6062 100% _run_code_in_main
0 0% 99.60% 6062 100% _run_in_app
0 0% 99.60% 6062 100% _run_main
0 0% 99.60% 3872 63.87% apply
0 0% 99.60% 161 2.66% apply_updates
0 0% 99.60% 6062 100% main
0 0% 99.60% 6062 100% main_function
0 0% 99.60% 6062 100% run
0 0% 99.60% 6062 100% runTests
0 0% 99.60% 6062 100% run_filename_as_main
0 0% 99.60% 6062 100% run_tests
0 0% 99.60% 3872 63.87% scope_fn
0 0% 99.60% 6062 100% test_train_and_evaluate
0 0% 99.60% 1159 19.12% update_fn
0 0% 99.60% 3872 63.87% wrapped_fn
0 0% 99.60% 3872 63.87% wrapped_module_method
0 0% 99.60% 3872 63.87% wrapper
```
I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```
[XLA:Python] Add helpers to Traceback and for working with pprof profiles.
* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.
PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Peter Hawkins
48bbdbc890
Change jax.core.DropVar to be a non-singleton.
...
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.
Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.
PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
3ac809ede3
[JAX] Move jax.util to jax._src_util.
...
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Peter Hawkins
7efc1dbc94
[JAX] Move source_info_util into jax._src.
...
TFP uses source_info_util, so we leave a forwarding stub until we can update TFP.
PiperOrigin-RevId: 340698612
2020-11-04 11:54:24 -08:00
Roy Frostig
d778a6d074
move experimental.jaxpr_stats to jaxpr_util
2020-08-18 18:07:38 -07:00