Matthew Johnson
0a693faf48
add pjit forwarding rule
...
Co-authored-by: Roy Frostig <frostig@google.com>
2024-05-25 17:46:01 +00:00
Sergei Lebedev
cbcaac2756
MAINT Migrate remaining internal/test modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008 .
2023-10-12 17:32:15 +01:00
Peter Hawkins
afd56c15d9
Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
...
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.
PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Yash Katariya
181355335c
Remove references to jax.config.jax_jit_pjit_api_merge
, which is always True at head.
...
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Yash Katariya
7b1128fdc4
Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
...
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Yash Katariya
849af498d1
Make jaxpr_util_test work with jit/pjit merge
...
PiperOrigin-RevId: 500841015
2023-01-09 16:50:04 -08:00
Yuxin Wu
96f6c1c9d4
Let is_user_frame ignore frames from stdlib.
...
When using decorators, we found contextlib.py from stdlib sometimes become the most recent non-jax frame. But it's not a user frame.
PiperOrigin-RevId: 486993924
2022-11-08 10:50:08 -08:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
c978df5550
Increase minimum jaxlib version to 0.3.0.
2022-03-04 10:33:03 -05:00
Peter Hawkins
dbf03d2ee6
[JAX] Add support for generating "equation profiles" in JAX.
...
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result.
For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
primitive: Total 6062.0
1509.0 (24.89%): mul
936.0 (15.44%): add
589.0 ( 9.72%): reshape
492.0 ( 8.12%): div
485.0 ( 8.00%): broadcast_in_dim
330.0 ( 5.44%): reduce_sum
322.0 ( 5.31%): integer_pow
230.0 ( 3.79%): add_any
174.0 ( 2.87%): convert_element_type
160.0 ( 2.64%): select
158.0 ( 2.61%): conv_general_dilated
116.0 ( 1.91%): sub
110.0 ( 1.81%): eq
110.0 ( 1.81%): neg
104.0 ( 1.72%): max
53.0 ( 0.87%): rsqrt
52.0 ( 0.86%): rev
49.0 ( 0.81%): custom_jvp_call_jaxpr
49.0 ( 0.81%): gt
5.0 (0.082%): xla_call
4.0 (0.066%): min
3.0 (0.049%): dot_general
3.0 (0.049%): lt
2.0 (0.033%): cos
2.0 (0.033%): exp
2.0 (0.033%): iota
2.0 (0.033%): log
2.0 (0.033%): psum
2.0 (0.033%): reduce_max
2.0 (0.033%): stop_gradient
1.0 (0.016%): argmax
1.0 (0.016%): reduce_window_max
1.0 (0.016%): select_and_scatter_add
1.0 (0.016%): transpose
1.0 (0.016%): xla_pmap
```
Or the lines of code that generated the most equations:
```
$ pprof --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
flat flat% sum% cum cum%
1537 25.35% 25.35% 1537 25.35% _compute_stats
1484 24.48% 49.84% 1484 24.48% _normalize
849 14.01% 63.84% 6062 100% __call__
644 10.62% 74.46% 644 10.62% <unknown>
483 7.97% 82.43% 483 7.97% <unknown>
392 6.47% 88.90% 6061 100% train_step
324 5.34% 94.24% 324 5.34% <unknown>
161 2.66% 96.90% 161 2.66% <unknown>
57 0.94% 97.84% 4292 70.80% loss_fn
52 0.86% 98.70% 52 0.86% schedule
39 0.64% 99.34% 39 0.64% softmax_cross_entropy
8 0.13% 99.47% 30 0.49% compute_metrics
6 0.099% 99.57% 61 1.01% cross_entropy_loss
1 0.016% 99.59% 1321 21.79% apply_gradients
1 0.016% 99.60% 6062 100% train_and_evaluate
0 0% 99.60% 6062 100% <unknown>
0 0% 99.60% 6062 100% __init__
0 0% 99.60% 3872 63.87% _call_wrapped_method
0 0% 99.60% 6062 100% _run_and_get_tests_result
0 0% 99.60% 6062 100% _run_code_in_main
0 0% 99.60% 6062 100% _run_in_app
0 0% 99.60% 6062 100% _run_main
0 0% 99.60% 3872 63.87% apply
0 0% 99.60% 161 2.66% apply_updates
0 0% 99.60% 6062 100% main
0 0% 99.60% 6062 100% main_function
0 0% 99.60% 6062 100% run
0 0% 99.60% 6062 100% runTests
0 0% 99.60% 6062 100% run_filename_as_main
0 0% 99.60% 6062 100% run_tests
0 0% 99.60% 3872 63.87% scope_fn
0 0% 99.60% 6062 100% test_train_and_evaluate
0 0% 99.60% 1159 19.12% update_fn
0 0% 99.60% 3872 63.87% wrapped_fn
0 0% 99.60% 3872 63.87% wrapped_module_method
0 0% 99.60% 3872 63.87% wrapper
```
I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```
[XLA:Python] Add helpers to Traceback and for working with pprof profiles.
* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.
PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Peter Hawkins
db2e91eba2
Move jax.test_util to jax._src.test_util.
...
Add forwarding shims for names used by external clients of JAX in practice.
PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
2fd682ef2a
Make jax_enable_x64 a thread-local value.
2021-02-04 09:48:22 -08:00
Roy Frostig
0daf4c00be
assume less about source locations in jaxpr_util_test
2020-10-13 15:48:04 -07:00
Roy Frostig
5135fd176d
fix jaxpr util test under enable_x64
2020-08-19 08:28:56 -07:00
Roy Frostig
d778a6d074
move experimental.jaxpr_stats to jaxpr_util
2020-08-18 18:07:38 -07:00