Yash Katariya
7f192c1946
Cache the expensive computations in GDA. For example, the result of get_shard_indices_replica_ids
can be the same for multiple variables in a neural network, because their global_shape, mesh_axes and global_mesh can be the same.
...
Note that the first call will be a little slow. The timings below show the caching working, because the benchmark runs for multiple iterations and the time is averaged over the number of iterations.
```
name time/op
gda_construction_callback_(4, 2)_['x', 'y'] 4.50ms ±10%
gda_construction_raw_(256, 8)_['x', 'y'] 5.82ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x', 'y'] 2.95ms ± 6%
indices_replica_id_calc_cached_(256, 8)_['x', 'y'] 28.7µs ± 1%
gda_construction_callback_(4, 2)_[None] 31.9ms ±20%
gda_construction_raw_(256, 8)_[None] 5.85ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_[None] 1.75ms ± 1%
indices_replica_id_calc_cached_(256, 8)_[None] 29.0µs ± 4%
gda_construction_callback_(4, 2)_['x'] 8.40ms ± 4%
gda_construction_raw_(256, 8)_['x'] 5.48ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x'] 1.89ms ± 1%
indices_replica_id_calc_cached_(256, 8)_['x'] 29.0µs ± 4%
gda_construction_callback_(4, 2)_['y'] 15.3ms ± 6%
gda_construction_raw_(256, 8)_['y'] 5.66ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_['y'] 1.82ms ± 2%
indices_replica_id_calc_cached_(256, 8)_['y'] 29.4µs ± 3%
gda_construction_callback_(4, 2)_[('x', 'y')] 4.29ms ± 5%
gda_construction_raw_(256, 8)_[('x', 'y')] 5.61ms ± 7%
indices_replica_id_calc__uncached_(256, 8)_[('x', 'y')] 3.81ms ±10%
indices_replica_id_calc_cached_(256, 8)_[('x', 'y')] 29.0µs ± 5%
gda_construction_raw_(128, 8)_['x', 'y'] 2.42ms ± 1%
indices_replica_id_calc__uncached_(128, 8)_['x', 'y'] 1.14ms ±11%
indices_replica_id_calc_cached_(128, 8)_['x', 'y'] 19.9µs ± 1%
gda_construction_raw_(4, 2)_['x', 'y'] 46.7µs ± 0%
indices_replica_id_calc__uncached_(4, 2)_['x', 'y'] 153µs ± 4%
indices_replica_id_calc_cached_(4, 2)_['x', 'y'] 11.1µs ± 8%
gda_construction_raw_(16, 4)_['x', 'y'] 164µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_['x', 'y'] 212µs ± 3%
indices_replica_id_calc_cached_(16, 4)_['x', 'y'] 11.3µs ± 1%
gda_construction_raw_(16, 4)_[('x', 'y')] 163µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_[('x', 'y')] 210µs ± 2%
indices_replica_id_calc_cached_(16, 4)_[('x', 'y')] 11.6µs ± 8%
```
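A minimal sketch of the caching idea, assuming a hypothetical `shard_slices` helper rather than the real `get_shard_indices_replica_ids`: the arguments are hashable tuples, so `functools.lru_cache` can serve repeated calls with the same shape and sharding from the cache.

```python
import functools
import itertools

# Hypothetical stand-in for an expensive index computation; not the GDA code.
@functools.lru_cache(maxsize=None)
def shard_slices(global_shape, shards_per_dim):
    """Return one tuple of slices per shard, assuming the shape divides evenly."""
    per_dim = []
    for size, n in zip(global_shape, shards_per_dim):
        step = size // n
        per_dim.append([slice(i * step, (i + 1) * step) for i in range(n)])
    return tuple(itertools.product(*per_dim))

first = shard_slices((256, 8), (4, 2))   # computed on the first call
second = shard_slices((256, 8), (4, 2))  # same key: served from the cache
info = shard_slices.cache_info()
```

Only the first call for a given key pays the full cost, which is why averaged benchmark timings mostly reflect the cached path.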
PiperOrigin-RevId: 422639127
2022-01-18 13:53:52 -08:00
jax authors
6e35d55c41
Merge pull request #9230 from NeilGirdhar:fix_annotation
...
PiperOrigin-RevId: 422638932
2022-01-18 13:49:11 -08:00
Peter Hawkins
329de7c9cc
Only use config.x64_enabled as the memo cache key for canonicalize_dtype, not any other fields.
...
This saves the time to repeatedly build a tuple as a cache key. Reduces the time for CustomLinearSolveTest.test_custom_linear_solve_pytree on my workstation from 110s to 85s.
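A rough illustration of keying a memo cache on a single flag, assuming a toy `Config` class and a toy dtype table (neither is JAX's actual implementation): only `x64_enabled` takes part in the cache key, so changes to unrelated fields neither cost time nor invalidate entries.

```python
import functools

class Config:
    """Toy config object; field names other than x64_enabled are made up."""
    x64_enabled = False
    other_flag = "unrelated"

config = Config()

# Toy mapping used when 64-bit types are disabled.
_NARROW = {"float64": "float32", "int64": "int32",
           "uint64": "uint32", "complex128": "complex64"}

@functools.lru_cache(maxsize=None)
def _canonicalize(dtype_name, x64_enabled):
    # The cache key is just (dtype_name, x64_enabled).
    if x64_enabled:
        return dtype_name
    return _NARROW.get(dtype_name, dtype_name)

def canonicalize(dtype_name):
    return _canonicalize(dtype_name, config.x64_enabled)

a = canonicalize("float64")
config.other_flag = "changed"   # not part of the key
b = canonicalize("float64")     # cache hit
hits_before_toggle = _canonicalize.cache_info().hits
config.x64_enabled = True
c = canonicalize("float64")     # new key, recomputed
```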
PiperOrigin-RevId: 422632700
2022-01-18 13:24:15 -08:00
jax authors
80dba64c8e
Merge pull request #9179 from allenlavoie:main
...
PiperOrigin-RevId: 422628545
2022-01-18 13:04:50 -08:00
jax authors
ad4b9f4948
Merge pull request #9094 from jakevdp:fix-update-slice
...
PiperOrigin-RevId: 422620656
2022-01-18 12:35:34 -08:00
jax authors
edf6efc2d8
Merge pull request #9209 from jakevdp:bcoo-sort-indices
...
PiperOrigin-RevId: 422620648
2022-01-18 12:30:38 -08:00
Neil Girdhar
b424d40c2b
Correct lax.while_loop type annotation
2022-01-18 15:14:02 -05:00
jax authors
e9f89d47f0
Merge pull request #9227 from mattjj:update-pypi
...
PiperOrigin-RevId: 422609723
2022-01-18 11:45:06 -08:00
Matthew Johnson
0066533dae
update version and changelog for pypi
2022-01-18 11:38:32 -08:00
Matthew Johnson
dc484bf450
Copybara import of the project:
...
--
06deb73c9be01cedc000efe7b3eb72d68615471a by Matthew Johnson <mattjj@google.com>:
cache initial-style jaxpr transformations
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/9196 from mattjj:issue3847 06deb73c9be01cedc000efe7b3eb72d68615471a
PiperOrigin-RevId: 422604879
2022-01-18 11:25:13 -08:00
Peter Hawkins
4c423c36d0
Speed up check_jaxpr().
...
(check_jaxpr() is only used when debugging.)
Don't eagerly pretty print jaxprs: only do so if we are going to raise an error.
Don't eagerly form error messages. Delete typecheck_assert.
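A small sketch of the lazy-message pattern described above, using hypothetical `check` and `pretty_print` helpers (not the functions in JAX): the message is passed as a callable, so the pretty printer only runs when an error is actually raised.

```python
calls = {"pretty_print": 0}

def pretty_print(equations):
    """Stand-in for an expensive pretty printer; counts how often it runs."""
    calls["pretty_print"] += 1
    return "\n".join(str(eqn) for eqn in equations)

def check(cond, make_msg):
    # make_msg is only called on failure, so passing checks stay cheap.
    if not cond:
        raise TypeError(make_msg())

eqns = ["a = add b c", "d = mul a a"]

check(len(eqns) == 2, lambda: "bad jaxpr:\n" + pretty_print(eqns))
calls_after_success = calls["pretty_print"]

try:
    check(len(eqns) == 3, lambda: "bad jaxpr:\n" + pretty_print(eqns))
except TypeError as e:
    error_text = str(e)
```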
PiperOrigin-RevId: 422594126
jax-v0.2.27
2022-01-18 10:42:14 -08:00
jax authors
e30b96cf80
Merge pull request #9201 from LenaMartens:changelist/420794552
...
PiperOrigin-RevId: 422589737
2022-01-18 10:24:20 -08:00
Jake VanderPlas
255d4b1b73
jax2tf: disable shape_poly_test for dynamic_update_slice
2022-01-18 10:13:09 -08:00
Jake VanderPlas
4832f09981
lax.dynamic_update_slice: fix batching rule
2022-01-18 10:07:22 -08:00
Lena Martens
8ea85769ea
Checkify: add way to disable categories of errors.
...
By default only user_asserts are lifted into the checked function.
2022-01-18 17:59:50 +00:00
Jake VanderPlas
16d6c4d027
[sparse] add bcoo_sort_indices
2022-01-18 09:59:26 -08:00
jax authors
6411f8a033
Merge pull request #9184 from jakevdp:unique-nan
...
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
jax authors
bebe9845a8
Merge pull request #9205 from jakevdp:einsum-tuple
...
PiperOrigin-RevId: 422013671
2022-01-15 02:46:56 -08:00
Yash Katariya
b92db58eaf
Canonicalize the parsed partition spec before passing it to lower_mesh_computation. Creates a new data structure, CanonicalizedParsedPartitionSpec,
which strips empty tuples from the end of parsed partitions to canonicalize the specs, so that, for example, P(None) and None in in_axis_resources are equivalent.
...
I have been bitten by this 3 times and it's about time I fixed it. This also fixes a bug with non-contiguous meshes, where fully replicated values are allowed but P(None) and None were not treated as equal.
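The stripping step can be sketched as follows. This assumes, for illustration only, that a parsed spec is a tuple with one tuple of axis names per dimension, so that P(None) parses to `((),)` and None parses to `()`; the `canonicalize` function is hypothetical and is not CanonicalizedParsedPartitionSpec itself.

```python
def canonicalize(parsed_spec):
    """Strip empty tuples from the end of a parsed partition spec."""
    parts = list(parsed_spec)
    while parts and parts[-1] == ():
        parts.pop()
    return tuple(parts)

p_none = ((),)   # assumed parsed form of P(None)
plain_none = ()  # assumed parsed form of None

same = canonicalize(p_none) == canonicalize(plain_none)
trailing = canonicalize((("x",), ()))   # trailing empty entry is dropped
leading = canonicalize(((), ("y",)))    # leading empty entry is kept
```

Only trailing empty entries are removed, since an empty entry before a sharded dimension still carries positional meaning.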
PiperOrigin-RevId: 421918164
2022-01-14 14:52:43 -08:00
Peter Hawkins
b509aae2a2
Split lax_control_flow_test into three separate tests.
...
Split the custom root and custom linear solve tests into separate test files.
Disable two slow custom linear solve tests.
Add a few jit decorators to slow tests in lax_control_flow_test.
PiperOrigin-RevId: 421901487
2022-01-14 13:36:34 -08:00
jax authors
c9169fa0d5
Merge pull request #9189 from gnecula:tf_reduce_window
...
PiperOrigin-RevId: 421875035
2022-01-14 11:35:16 -08:00
Jake VanderPlas
77d60cf4dd
einsum: clarify use of precision.
2022-01-14 11:08:13 -08:00
jax authors
7f07f1b86b
Merge pull request #9200 from google:LenaMartens-patch-1
...
PiperOrigin-RevId: 421845558
2022-01-14 09:30:24 -08:00
Lena Martens
f591d0b2e9
Add ensure_compile_time_eval docstring to docs
2022-01-14 11:18:40 +00:00
Jake VanderPlas
bd157cf056
jnp.unique: properly handle NaN values
2022-01-13 15:54:07 -08:00
jax authors
3b374e7dd9
Merge pull request #9127 from jakevdp:searchsorted-complex
...
PiperOrigin-RevId: 421672606
2022-01-13 15:25:41 -08:00
Jake VanderPlas
8ca10ea53f
searchsorted: use correct ordering for complex inputs
2022-01-13 13:45:59 -08:00
jax authors
cd73a4195f
Merge pull request #9178 from jakevdp:sort-corner-cases
...
PiperOrigin-RevId: 421646783
2022-01-13 13:33:02 -08:00
Jake VanderPlas
d8bdd9a19d
lax.sort: regularize handling of -0.0 and -NaN
2022-01-13 13:03:41 -08:00
Yash Katariya
0532a63261
Optimizations for GDA to make creating GDA faster.
...
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function, not the `_HashableIndex` class, which no longer exists) is 1.5 - 2 times slower than using math. markdaoust@ helped with the math here (going to the office has its own perks :) )
* Get rid of the `_HashableIndex` class and replace it with a function, `_hashed_index`. Dataclass is extremely slow.
* Only calculate global_mesh.local_devices once, even though it's a cached property (but only from Python 3.8 onward).
```
name old time/op new time/op delta
gda_construction_callback_(4, 2)_['x', 'y'] 4.77ms ± 5% 4.74ms ± 5% ~ (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y'] 17.9ms ± 5% 9.0ms ± 2% -49.92% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y'] 11.4ms ± 2% 2.9ms ± 2% -74.52% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None] 34.0ms ±20% 30.5ms ± 2% ~ (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None] 15.9ms ± 2% 7.7ms ± 3% -51.56% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None] 9.39ms ± 3% 1.74ms ± 2% -81.44% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x'] 8.87ms ± 2% 8.92ms ± 3% ~ (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x'] 16.4ms ± 2% 7.7ms ± 1% -52.66% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x'] 9.85ms ± 1% 1.90ms ± 2% -80.68% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y'] 15.9ms ± 3% 16.0ms ± 5% ~ (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y'] 15.8ms ± 3% 7.6ms ± 1% -52.04% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y'] 9.29ms ± 1% 1.78ms ± 1% -80.79% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')] 4.65ms ± 2% 4.62ms ± 3% ~ (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')] 18.6ms ± 3% 9.7ms ± 5% -47.76% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')] 11.8ms ± 4% 3.5ms ± 2% -70.28% (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y'] 8.54ms ± 1% 4.03ms ± 2% -52.84% (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y'] 5.40ms ± 4% 1.10ms ± 1% -79.69% (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y'] 173µs ± 1% 193µs ± 3% +11.63% (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y'] 127µs ± 1% 147µs ± 1% +15.57% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00
Yuanzhong Xu
2e4687a62e
PartitionSpec: allow partially specified sharding
...
PiperOrigin-RevId: 421603194
2022-01-13 10:35:15 -08:00
George Necula
5bfe1852a4
[jax2tf] Add jax2tf_associative_scan_reductions flag
...
This flag allows users to match the JAX performance for
associative reductions in CPU.
See README.md for details.
2022-01-13 15:52:18 +02:00
jax authors
f0e4f0472d
Merge pull request #9186 from froystig:get-default-rng
...
PiperOrigin-RevId: 421454645
2022-01-12 19:48:04 -08:00
Roy Frostig
026b91b85d
add random.default_prng_impl
to retrieve the default PRNG implementation
2022-01-12 19:13:14 -08:00
jax authors
436ce7904c
Merge pull request #9175 from froystig:custom-xform-wrappers-forward-attrs
...
PiperOrigin-RevId: 421449851
2022-01-12 19:11:22 -08:00
jax authors
f08bb50bfa
Merge pull request #8869 from mbmccoy:issue8744
...
PiperOrigin-RevId: 421404424
2022-01-12 15:07:47 -08:00
Peter Hawkins
dbf03d2ee6
[JAX] Add support for generating "equation profiles" in JAX.
...
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used in a number of ways to analyze and visualize the result.
For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
primitive: Total 6062.0
1509.0 (24.89%): mul
936.0 (15.44%): add
589.0 ( 9.72%): reshape
492.0 ( 8.12%): div
485.0 ( 8.00%): broadcast_in_dim
330.0 ( 5.44%): reduce_sum
322.0 ( 5.31%): integer_pow
230.0 ( 3.79%): add_any
174.0 ( 2.87%): convert_element_type
160.0 ( 2.64%): select
158.0 ( 2.61%): conv_general_dilated
116.0 ( 1.91%): sub
110.0 ( 1.81%): eq
110.0 ( 1.81%): neg
104.0 ( 1.72%): max
53.0 ( 0.87%): rsqrt
52.0 ( 0.86%): rev
49.0 ( 0.81%): custom_jvp_call_jaxpr
49.0 ( 0.81%): gt
5.0 (0.082%): xla_call
4.0 (0.066%): min
3.0 (0.049%): dot_general
3.0 (0.049%): lt
2.0 (0.033%): cos
2.0 (0.033%): exp
2.0 (0.033%): iota
2.0 (0.033%): log
2.0 (0.033%): psum
2.0 (0.033%): reduce_max
2.0 (0.033%): stop_gradient
1.0 (0.016%): argmax
1.0 (0.016%): reduce_window_max
1.0 (0.016%): select_and_scatter_add
1.0 (0.016%): transpose
1.0 (0.016%): xla_pmap
```
Or the lines of code that generated the most equations:
```
$ pprof --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
flat flat% sum% cum cum%
1537 25.35% 25.35% 1537 25.35% _compute_stats
1484 24.48% 49.84% 1484 24.48% _normalize
849 14.01% 63.84% 6062 100% __call__
644 10.62% 74.46% 644 10.62% <unknown>
483 7.97% 82.43% 483 7.97% <unknown>
392 6.47% 88.90% 6061 100% train_step
324 5.34% 94.24% 324 5.34% <unknown>
161 2.66% 96.90% 161 2.66% <unknown>
57 0.94% 97.84% 4292 70.80% loss_fn
52 0.86% 98.70% 52 0.86% schedule
39 0.64% 99.34% 39 0.64% softmax_cross_entropy
8 0.13% 99.47% 30 0.49% compute_metrics
6 0.099% 99.57% 61 1.01% cross_entropy_loss
1 0.016% 99.59% 1321 21.79% apply_gradients
1 0.016% 99.60% 6062 100% train_and_evaluate
0 0% 99.60% 6062 100% <unknown>
0 0% 99.60% 6062 100% __init__
0 0% 99.60% 3872 63.87% _call_wrapped_method
0 0% 99.60% 6062 100% _run_and_get_tests_result
0 0% 99.60% 6062 100% _run_code_in_main
0 0% 99.60% 6062 100% _run_in_app
0 0% 99.60% 6062 100% _run_main
0 0% 99.60% 3872 63.87% apply
0 0% 99.60% 161 2.66% apply_updates
0 0% 99.60% 6062 100% main
0 0% 99.60% 6062 100% main_function
0 0% 99.60% 6062 100% run
0 0% 99.60% 6062 100% runTests
0 0% 99.60% 6062 100% run_filename_as_main
0 0% 99.60% 6062 100% run_tests
0 0% 99.60% 3872 63.87% scope_fn
0 0% 99.60% 6062 100% test_train_and_evaluate
0 0% 99.60% 1159 19.12% update_fn
0 0% 99.60% 3872 63.87% wrapped_fn
0 0% 99.60% 3872 63.87% wrapped_module_method
0 0% 99.60% 3872 63.87% wrapper
```
I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, together with the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```
[XLA:Python] Add helpers to Traceback and for working with pprof profiles.
* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.
PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Mike McCoy
2e5ab11652
Resolves issue 8744
2022-01-12 21:10:45 +00:00
Allen Lavoie
65ecea71fd
Remove trailing whitespace
2022-01-12 15:26:55 -05:00
Allen Lavoie
bd061e6e3c
Fix tf backprop through bfloat16 jax2tf
2022-01-12 14:54:41 -05:00
Allen Lavoie
ad85eab6ea
Add a jax2tf bfloat16 backprop test
2022-01-12 14:52:40 -05:00
jax authors
2e375f04a6
Merge pull request #8915 from mattjj:post-process-revisions
...
PiperOrigin-RevId: 421311438
2022-01-12 08:57:53 -08:00
Matthew Johnson
08aec823fd
fix a custom_vjp post_process bug, related cleanups
...
related to #8783, doesn't completely fix it
2022-01-12 07:51:50 -08:00
Roy Frostig
ddc1c3e9bd
enable custom transformation "stacking"
...
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:
```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...
@f.def_transpose
def f_t(y): ...
@f.defjvp
def f_jvp(x, tx): ...
```
In particular:
* Forward `def*` methods on custom transformations.
* Have unary `def*` methods return their argument so that, when used
as decorators, they do not replace their target with `None`.
* Fix a bug in the use of `functools.update_wrapper`: previously a
wrapper would overwrite its own attributes with those of the target
callable (including its reference to the target callable).
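To illustrate the second point, here is a minimal sketch with a hypothetical `Wrapper` class (not `custom_jvp` or `custom_transpose`): because `def_rule` returns its argument, using it as a decorator leaves the decorated name bound to the function instead of `None`.

```python
class Wrapper:
    """Toy transformation wrapper; names here are made up for illustration."""

    def __init__(self, fun):
        self.fun = fun
        self.rule = None

    def def_rule(self, rule):
        self.rule = rule
        return rule  # returning the argument keeps the decorated name usable

    def __call__(self, *args):
        return self.fun(*args)

@Wrapper
def f(x):
    return 2 * x

@f.def_rule
def f_rule(x):
    return x + 1

result = f(3)
rule_result = f_rule(1)  # would fail if def_rule returned None
```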
2022-01-11 17:55:08 -08:00
jax authors
4dd1f001c6
Merge pull request #9162 from mattjj:simplify-memoize
...
PiperOrigin-RevId: 421154196
2022-01-11 16:34:20 -08:00
Yash Katariya
fbb8b9f8c6
Benchmarks for GDA. Also move create_global_mesh to test_utils since it was replicated in a lot of places.
...
PiperOrigin-RevId: 421142813
2022-01-11 15:43:05 -08:00
jax authors
f235edbf5c
Merge pull request #9134 from froystig:custom-transpose
...
PiperOrigin-RevId: 421119056
2022-01-11 14:00:45 -08:00
Roy Frostig
1709e06800
introduce custom_transpose and a corresponding primitive
...
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00
jax authors
a30ec029ee
Merge pull request #9164 from mattjj:checkify-tweaks
...
PiperOrigin-RevId: 421098252
2022-01-11 12:36:31 -08:00
Matthew Johnson
6850833c3a
checkify: tweak some organization and names
2022-01-10 21:29:12 -08:00