16 Commits

Author SHA1 Message Date
Matthew Johnson
0a693faf48 add pjit forwarding rule
Co-authored-by: Roy Frostig <frostig@google.com>
2024-05-25 17:46:01 +00:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
afd56c15d9 Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Yash Katariya
7b1128fdc4 Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Yash Katariya
849af498d1 Make jaxpr_util_test work with jit/pjit merge
PiperOrigin-RevId: 500841015
2023-01-09 16:50:04 -08:00
Yuxin Wu
96f6c1c9d4 Let is_user_frame ignore frames from stdlib.
When using decorators, we found contextlib.py from stdlib sometimes become the most recent non-jax frame. But it's not a user frame.

PiperOrigin-RevId: 486993924
2022-11-08 10:50:08 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Peter Hawkins
dbf03d2ee6 [JAX] Add support for generating "equation profiles" in JAX.
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result.

For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
 primitive: Total 6062.0
            1509.0 (24.89%): mul
             936.0 (15.44%): add
             589.0 ( 9.72%): reshape
             492.0 ( 8.12%): div
             485.0 ( 8.00%): broadcast_in_dim
             330.0 ( 5.44%): reduce_sum
             322.0 ( 5.31%): integer_pow
             230.0 ( 3.79%): add_any
             174.0 ( 2.87%): convert_element_type
             160.0 ( 2.64%): select
             158.0 ( 2.61%): conv_general_dilated
             116.0 ( 1.91%): sub
             110.0 ( 1.81%): eq
             110.0 ( 1.81%): neg
             104.0 ( 1.72%): max
              53.0 ( 0.87%): rsqrt
              52.0 ( 0.86%): rev
              49.0 ( 0.81%): custom_jvp_call_jaxpr
              49.0 ( 0.81%): gt
               5.0 (0.082%): xla_call
               4.0 (0.066%): min
               3.0 (0.049%): dot_general
               3.0 (0.049%): lt
               2.0 (0.033%): cos
               2.0 (0.033%): exp
               2.0 (0.033%): iota
               2.0 (0.033%): log
               2.0 (0.033%): psum
               2.0 (0.033%): reduce_max
               2.0 (0.033%): stop_gradient
               1.0 (0.016%): argmax
               1.0 (0.016%): reduce_window_max
               1.0 (0.016%): select_and_scatter_add
               1.0 (0.016%): transpose
               1.0 (0.016%): xla_pmap
```

Or the lines of code that generated the most equations:
```
$ pprof  --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
      flat  flat%   sum%        cum   cum%
      1537 25.35% 25.35%       1537 25.35%  _compute_stats
      1484 24.48% 49.84%       1484 24.48%  _normalize
       849 14.01% 63.84%       6062   100%  __call__
       644 10.62% 74.46%        644 10.62%  <unknown>
       483  7.97% 82.43%        483  7.97%  <unknown>
       392  6.47% 88.90%       6061   100%  train_step
       324  5.34% 94.24%        324  5.34%  <unknown>
       161  2.66% 96.90%        161  2.66%  <unknown>
        57  0.94% 97.84%       4292 70.80%  loss_fn
        52  0.86% 98.70%         52  0.86%  schedule
        39  0.64% 99.34%         39  0.64%  softmax_cross_entropy
         8  0.13% 99.47%         30  0.49%  compute_metrics
         6 0.099% 99.57%         61  1.01%  cross_entropy_loss
         1 0.016% 99.59%       1321 21.79%  apply_gradients
         1 0.016% 99.60%       6062   100%  train_and_evaluate
         0     0% 99.60%       6062   100%  <unknown>
         0     0% 99.60%       6062   100%  __init__
         0     0% 99.60%       3872 63.87%  _call_wrapped_method
         0     0% 99.60%       6062   100%  _run_and_get_tests_result
         0     0% 99.60%       6062   100%  _run_code_in_main
         0     0% 99.60%       6062   100%  _run_in_app
         0     0% 99.60%       6062   100%  _run_main
         0     0% 99.60%       3872 63.87%  apply
         0     0% 99.60%        161  2.66%  apply_updates
         0     0% 99.60%       6062   100%  main
         0     0% 99.60%       6062   100%  main_function
         0     0% 99.60%       6062   100%  run
         0     0% 99.60%       6062   100%  runTests
         0     0% 99.60%       6062   100%  run_filename_as_main
         0     0% 99.60%       6062   100%  run_tests
         0     0% 99.60%       3872 63.87%  scope_fn
         0     0% 99.60%       6062   100%  test_train_and_evaluate
         0     0% 99.60%       1159 19.12%  update_fn
         0     0% 99.60%       3872 63.87%  wrapped_fn
         0     0% 99.60%       3872 63.87%  wrapped_module_method
         0     0% 99.60%       3872 63.87%  wrapper
```

I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```

[XLA:Python] Add helpers to Traceback and for working with pprof profiles.

* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.

PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Roy Frostig
0daf4c00be assume less about source locations in jaxpr_util_test 2020-10-13 15:48:04 -07:00
Roy Frostig
5135fd176d fix jaxpr util test under enable_x64 2020-08-19 08:28:56 -07:00
Roy Frostig
d778a6d074 move experimental.jaxpr_stats to jaxpr_util 2020-08-18 18:07:38 -07:00