11 Commits

Author SHA1 Message Date
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
2a422c7203 Fix or ignore some pytype errors.
PiperOrigin-RevId: 452208582
2022-05-31 21:25:55 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Peter Hawkins
82d8261308 Speed up source location computation when lowering a jaxpr to HLO/MHLO.
Speed up source_info_util.user_frames by using a newly refactored Traceback.raw_frames() attribute. Since we are interested only in one frame, it's best to avoid doing wasted work on all the frames we are going to ignore.

Change traceback.raw_frames() to return the transpose of what it previously returned because it means we only need to build 3 Python objects, rather than n + 1 Python objects for n frames.

PiperOrigin-RevId: 427320674
2022-02-08 16:17:40 -08:00
Peter Hawkins
dbf03d2ee6 [JAX] Add support for generating "equation profiles" in JAX.
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result.

For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
 primitive: Total 6062.0
            1509.0 (24.89%): mul
             936.0 (15.44%): add
             589.0 ( 9.72%): reshape
             492.0 ( 8.12%): div
             485.0 ( 8.00%): broadcast_in_dim
             330.0 ( 5.44%): reduce_sum
             322.0 ( 5.31%): integer_pow
             230.0 ( 3.79%): add_any
             174.0 ( 2.87%): convert_element_type
             160.0 ( 2.64%): select
             158.0 ( 2.61%): conv_general_dilated
             116.0 ( 1.91%): sub
             110.0 ( 1.81%): eq
             110.0 ( 1.81%): neg
             104.0 ( 1.72%): max
              53.0 ( 0.87%): rsqrt
              52.0 ( 0.86%): rev
              49.0 ( 0.81%): custom_jvp_call_jaxpr
              49.0 ( 0.81%): gt
               5.0 (0.082%): xla_call
               4.0 (0.066%): min
               3.0 (0.049%): dot_general
               3.0 (0.049%): lt
               2.0 (0.033%): cos
               2.0 (0.033%): exp
               2.0 (0.033%): iota
               2.0 (0.033%): log
               2.0 (0.033%): psum
               2.0 (0.033%): reduce_max
               2.0 (0.033%): stop_gradient
               1.0 (0.016%): argmax
               1.0 (0.016%): reduce_window_max
               1.0 (0.016%): select_and_scatter_add
               1.0 (0.016%): transpose
               1.0 (0.016%): xla_pmap
```

Or the lines of code that generated the most equations:
```
$ pprof  --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
      flat  flat%   sum%        cum   cum%
      1537 25.35% 25.35%       1537 25.35%  _compute_stats
      1484 24.48% 49.84%       1484 24.48%  _normalize
       849 14.01% 63.84%       6062   100%  __call__
       644 10.62% 74.46%        644 10.62%  <unknown>
       483  7.97% 82.43%        483  7.97%  <unknown>
       392  6.47% 88.90%       6061   100%  train_step
       324  5.34% 94.24%        324  5.34%  <unknown>
       161  2.66% 96.90%        161  2.66%  <unknown>
        57  0.94% 97.84%       4292 70.80%  loss_fn
        52  0.86% 98.70%         52  0.86%  schedule
        39  0.64% 99.34%         39  0.64%  softmax_cross_entropy
         8  0.13% 99.47%         30  0.49%  compute_metrics
         6 0.099% 99.57%         61  1.01%  cross_entropy_loss
         1 0.016% 99.59%       1321 21.79%  apply_gradients
         1 0.016% 99.60%       6062   100%  train_and_evaluate
         0     0% 99.60%       6062   100%  <unknown>
         0     0% 99.60%       6062   100%  __init__
         0     0% 99.60%       3872 63.87%  _call_wrapped_method
         0     0% 99.60%       6062   100%  _run_and_get_tests_result
         0     0% 99.60%       6062   100%  _run_code_in_main
         0     0% 99.60%       6062   100%  _run_in_app
         0     0% 99.60%       6062   100%  _run_main
         0     0% 99.60%       3872 63.87%  apply
         0     0% 99.60%        161  2.66%  apply_updates
         0     0% 99.60%       6062   100%  main
         0     0% 99.60%       6062   100%  main_function
         0     0% 99.60%       6062   100%  run
         0     0% 99.60%       6062   100%  runTests
         0     0% 99.60%       6062   100%  run_filename_as_main
         0     0% 99.60%       6062   100%  run_tests
         0     0% 99.60%       3872 63.87%  scope_fn
         0     0% 99.60%       6062   100%  test_train_and_evaluate
         0     0% 99.60%       1159 19.12%  update_fn
         0     0% 99.60%       3872 63.87%  wrapped_fn
         0     0% 99.60%       3872 63.87%  wrapped_module_method
         0     0% 99.60%       3872 63.87%  wrapper
```

I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```

[XLA:Python] Add helpers to Traceback and for working with pprof profiles.

* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.

PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Peter Hawkins
7efc1dbc94 [JAX] Move source_info_util into jax._src.
TFP uses source_info_util, so we leave a forwarding stub until we can update TFP.

PiperOrigin-RevId: 340698612
2020-11-04 11:54:24 -08:00
Roy Frostig
d778a6d074 move experimental.jaxpr_stats to jaxpr_util 2020-08-18 18:07:38 -07:00