10211 Commits

Author SHA1 Message Date
Yash Katariya
7f192c1946 Cache the expensive computations in GDA. For example get_shard_indices_replica_ids can be the same for multiple variables in a neural network (global_shape, mesh_axes and global_mesh) can be the same
Note that the first time will be a little slow. The below timings you are seeing shows the caching working because the benchmark is running for multiple iterations and then the time is averaged over the number of iterations.

```
name                                                     time/op
gda_construction_callback_(4, 2)_['x', 'y']              4.50ms ±10%
gda_construction_raw_(256, 8)_['x', 'y']                 5.82ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x', 'y']    2.95ms ± 6%
indices_replica_id_calc_cached_(256, 8)_['x', 'y']       28.7µs ± 1%
gda_construction_callback_(4, 2)_[None]                  31.9ms ±20%
gda_construction_raw_(256, 8)_[None]                     5.85ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_[None]        1.75ms ± 1%
indices_replica_id_calc_cached_(256, 8)_[None]           29.0µs ± 4%
gda_construction_callback_(4, 2)_['x']                   8.40ms ± 4%
gda_construction_raw_(256, 8)_['x']                      5.48ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x']         1.89ms ± 1%
indices_replica_id_calc_cached_(256, 8)_['x']            29.0µs ± 4%
gda_construction_callback_(4, 2)_['y']                   15.3ms ± 6%
gda_construction_raw_(256, 8)_['y']                      5.66ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_['y']         1.82ms ± 2%
indices_replica_id_calc_cached_(256, 8)_['y']            29.4µs ± 3%
gda_construction_callback_(4, 2)_[('x', 'y')]            4.29ms ± 5%
gda_construction_raw_(256, 8)_[('x', 'y')]               5.61ms ± 7%
indices_replica_id_calc__uncached_(256, 8)_[('x', 'y')]  3.81ms ±10%
indices_replica_id_calc_cached_(256, 8)_[('x', 'y')]     29.0µs ± 5%
gda_construction_raw_(128, 8)_['x', 'y']                 2.42ms ± 1%
indices_replica_id_calc__uncached_(128, 8)_['x', 'y']    1.14ms ±11%
indices_replica_id_calc_cached_(128, 8)_['x', 'y']       19.9µs ± 1%
gda_construction_raw_(4, 2)_['x', 'y']                   46.7µs ± 0%
indices_replica_id_calc__uncached_(4, 2)_['x', 'y']       153µs ± 4%
indices_replica_id_calc_cached_(4, 2)_['x', 'y']         11.1µs ± 8%
gda_construction_raw_(16, 4)_['x', 'y']                   164µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_['x', 'y']      212µs ± 3%
indices_replica_id_calc_cached_(16, 4)_['x', 'y']        11.3µs ± 1%
gda_construction_raw_(16, 4)_[('x', 'y')]                 163µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_[('x', 'y')]    210µs ± 2%
indices_replica_id_calc_cached_(16, 4)_[('x', 'y')]      11.6µs ± 8%
```

PiperOrigin-RevId: 422639127
2022-01-18 13:53:52 -08:00
jax authors
6e35d55c41 Merge pull request #9230 from NeilGirdhar:fix_annotation
PiperOrigin-RevId: 422638932
2022-01-18 13:49:11 -08:00
Peter Hawkins
329de7c9cc Only use config.x64_enabled as the memo cache key for canonicalize_dtype, not any other fields.
This saves the time to repeatedly build a tuple as a cache key. Reduces the time for CustomLinearSolveTest.test_custom_linear_solve_pytree on my workstation from 110s to 85s.

PiperOrigin-RevId: 422632700
2022-01-18 13:24:15 -08:00
jax authors
80dba64c8e Merge pull request #9179 from allenlavoie:main
PiperOrigin-RevId: 422628545
2022-01-18 13:04:50 -08:00
jax authors
ad4b9f4948 Merge pull request #9094 from jakevdp:fix-update-slice
PiperOrigin-RevId: 422620656
2022-01-18 12:35:34 -08:00
jax authors
edf6efc2d8 Merge pull request #9209 from jakevdp:bcoo-sort-indices
PiperOrigin-RevId: 422620648
2022-01-18 12:30:38 -08:00
Neil Girdhar
b424d40c2b Correct lax.while_loop type annotation 2022-01-18 15:14:02 -05:00
jax authors
e9f89d47f0 Merge pull request #9227 from mattjj:update-pypi
PiperOrigin-RevId: 422609723
2022-01-18 11:45:06 -08:00
Matthew Johnson
0066533dae update version and changelog for pypi 2022-01-18 11:38:32 -08:00
Matthew Johnson
dc484bf450 Copybara import of the project:
--
06deb73c9be01cedc000efe7b3eb72d68615471a by Matthew Johnson <mattjj@google.com>:

cache initial-style jaxpr transformations

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/9196 from mattjj:issue3847 06deb73c9be01cedc000efe7b3eb72d68615471a
PiperOrigin-RevId: 422604879
2022-01-18 11:25:13 -08:00
Peter Hawkins
4c423c36d0 Speed up check_jaxpr().
(check_jaxpr() is only used when debugging.)

Don't eagerly pretty print jaxprs: only do so if we are going to raise an error.
Don't eagerly form error messages. Delete typecheck_assert.

PiperOrigin-RevId: 422594126
jax-v0.2.27
2022-01-18 10:42:14 -08:00
jax authors
e30b96cf80 Merge pull request #9201 from LenaMartens:changelist/420794552
PiperOrigin-RevId: 422589737
2022-01-18 10:24:20 -08:00
Jake VanderPlas
255d4b1b73 jax2tf: disable shape_poly_test for dynamic_update_slice 2022-01-18 10:13:09 -08:00
Jake VanderPlas
4832f09981 lax.dynamic_update_slice: fix batching rule 2022-01-18 10:07:22 -08:00
Lena Martens
8ea85769ea Checkify: add way to disable categories of errors.
By default only user_asserts are lifted into the checked function.
2022-01-18 17:59:50 +00:00
Jake VanderPlas
16d6c4d027 [sparse] add bcoo_sort_indices 2022-01-18 09:59:26 -08:00
jax authors
6411f8a033 Merge pull request #9184 from jakevdp:unique-nan
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
jax authors
bebe9845a8 Merge pull request #9205 from jakevdp:einsum-tuple
PiperOrigin-RevId: 422013671
2022-01-15 02:46:56 -08:00
Yash Katariya
b92db58eaf Canonicalize parsed partition spec before passing to lower_mesh_computation. Creates a new data structure CanonicalizedParsedPartitionSpec which strips empty tuples from the end of parsed partitions to canonicalize the specs so that P(None) and None for example in in_axis_resources are equivalent.
I have been bit by this 3 times and its about time I fix this. This also fixes a bug where fully replicated values are allowed with non-contiguous meshes (in this case P(None) and None) were not equal.

PiperOrigin-RevId: 421918164
2022-01-14 14:52:43 -08:00
Peter Hawkins
b509aae2a2 Split lax_control_flow_test into three separate tests.
Split the custom root and custom linear solve tests into separate test files.

Disable two slow custom linear solve tests.
Add a few jit decorators to slow tests in lax_control_flow_test.

PiperOrigin-RevId: 421901487
2022-01-14 13:36:34 -08:00
jax authors
c9169fa0d5 Merge pull request #9189 from gnecula:tf_reduce_window
PiperOrigin-RevId: 421875035
2022-01-14 11:35:16 -08:00
Jake VanderPlas
77d60cf4dd einsum: clarify use of precision. 2022-01-14 11:08:13 -08:00
jax authors
7f07f1b86b Merge pull request #9200 from google:LenaMartens-patch-1
PiperOrigin-RevId: 421845558
2022-01-14 09:30:24 -08:00
Lena Martens
f591d0b2e9
Add ensure_compile_time_eval docstring to docs 2022-01-14 11:18:40 +00:00
Jake VanderPlas
bd157cf056 jnp.unique: properly handle NaN values 2022-01-13 15:54:07 -08:00
jax authors
3b374e7dd9 Merge pull request #9127 from jakevdp:searchsorted-complex
PiperOrigin-RevId: 421672606
2022-01-13 15:25:41 -08:00
Jake VanderPlas
8ca10ea53f searchsorted: use correct ordering for complex inputs 2022-01-13 13:45:59 -08:00
jax authors
cd73a4195f Merge pull request #9178 from jakevdp:sort-corner-cases
PiperOrigin-RevId: 421646783
2022-01-13 13:33:02 -08:00
Jake VanderPlas
d8bdd9a19d lax.sort: regularize handling of -0.0 and -NaN 2022-01-13 13:03:41 -08:00
Yash Katariya
0532a63261 Optimizations for GDA to make creating GDA faster.
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function and not `_HashableIndex` which is a class which does not exist anymore) is 1.5 - 2 times slower than using math. markdaoust@ helped with the math here (going to office has its own perks :) )

* Get rid of `_HashableIndex` class and replace it with a function `_hashed_index`. Dataclass is extremely slow.

* Only calculate global_mesh.local_devices once. Even though its a cached property (but its after python 3.8)

```
name                                           old time/op             new time/op             delta
gda_construction_callback_(4, 2)_['x', 'y']    4.77ms ± 5%             4.74ms ± 5%     ~           (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y']       17.9ms ± 5%              9.0ms ± 2%  -49.92%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y']    11.4ms ± 2%              2.9ms ± 2%  -74.52%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None]        34.0ms ±20%             30.5ms ± 2%     ~             (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None]           15.9ms ± 2%              7.7ms ± 3%  -51.56%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None]        9.39ms ± 3%             1.74ms ± 2%  -81.44%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x']         8.87ms ± 2%             8.92ms ± 3%     ~             (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x']            16.4ms ± 2%              7.7ms ± 1%  -52.66%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x']         9.85ms ± 1%             1.90ms ± 2%  -80.68%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y']         15.9ms ± 3%             16.0ms ± 5%     ~             (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y']            15.8ms ± 3%              7.6ms ± 1%  -52.04%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y']         9.29ms ± 1%             1.78ms ± 1%  -80.79%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')]  4.65ms ± 2%             4.62ms ± 3%     ~            (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')]     18.6ms ± 3%              9.7ms ± 5%  -47.76%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')]  11.8ms ± 4%              3.5ms ± 2%  -70.28%          (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y']       8.54ms ± 1%             4.03ms ± 2%  -52.84%          (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y']    5.40ms ± 4%             1.10ms ± 1%  -79.69%          (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y']          173µs ± 1%              193µs ± 3%  +11.63%          (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y']       127µs ± 1%              147µs ± 1%  +15.57%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00
Yuanzhong Xu
2e4687a62e PartitionSpec: allow partially specified sharding
PiperOrigin-RevId: 421603194
2022-01-13 10:35:15 -08:00
George Necula
5bfe1852a4 [jax2tf] Add jax2tf_associative_scan_reductions flag
This flag allows users to match the JAX performance for
associative reductions in CPU.
See README.md for details.
2022-01-13 15:52:18 +02:00
jax authors
f0e4f0472d Merge pull request #9186 from froystig:get-default-rng
PiperOrigin-RevId: 421454645
2022-01-12 19:48:04 -08:00
Roy Frostig
026b91b85d add random.default_prng_impl to retrieve the default PRNG implementation 2022-01-12 19:13:14 -08:00
jax authors
436ce7904c Merge pull request #9175 from froystig:custom-xform-wrappers-forward-attrs
PiperOrigin-RevId: 421449851
2022-01-12 19:11:22 -08:00
jax authors
f08bb50bfa Merge pull request #8869 from mbmccoy:issue8744
PiperOrigin-RevId: 421404424
2022-01-12 15:07:47 -08:00
Peter Hawkins
dbf03d2ee6 [JAX] Add support for generating "equation profiles" in JAX.
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result.

For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
 primitive: Total 6062.0
            1509.0 (24.89%): mul
             936.0 (15.44%): add
             589.0 ( 9.72%): reshape
             492.0 ( 8.12%): div
             485.0 ( 8.00%): broadcast_in_dim
             330.0 ( 5.44%): reduce_sum
             322.0 ( 5.31%): integer_pow
             230.0 ( 3.79%): add_any
             174.0 ( 2.87%): convert_element_type
             160.0 ( 2.64%): select
             158.0 ( 2.61%): conv_general_dilated
             116.0 ( 1.91%): sub
             110.0 ( 1.81%): eq
             110.0 ( 1.81%): neg
             104.0 ( 1.72%): max
              53.0 ( 0.87%): rsqrt
              52.0 ( 0.86%): rev
              49.0 ( 0.81%): custom_jvp_call_jaxpr
              49.0 ( 0.81%): gt
               5.0 (0.082%): xla_call
               4.0 (0.066%): min
               3.0 (0.049%): dot_general
               3.0 (0.049%): lt
               2.0 (0.033%): cos
               2.0 (0.033%): exp
               2.0 (0.033%): iota
               2.0 (0.033%): log
               2.0 (0.033%): psum
               2.0 (0.033%): reduce_max
               2.0 (0.033%): stop_gradient
               1.0 (0.016%): argmax
               1.0 (0.016%): reduce_window_max
               1.0 (0.016%): select_and_scatter_add
               1.0 (0.016%): transpose
               1.0 (0.016%): xla_pmap
```

Or the lines of code that generated the most equations:
```
$ pprof  --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
      flat  flat%   sum%        cum   cum%
      1537 25.35% 25.35%       1537 25.35%  _compute_stats
      1484 24.48% 49.84%       1484 24.48%  _normalize
       849 14.01% 63.84%       6062   100%  __call__
       644 10.62% 74.46%        644 10.62%  <unknown>
       483  7.97% 82.43%        483  7.97%  <unknown>
       392  6.47% 88.90%       6061   100%  train_step
       324  5.34% 94.24%        324  5.34%  <unknown>
       161  2.66% 96.90%        161  2.66%  <unknown>
        57  0.94% 97.84%       4292 70.80%  loss_fn
        52  0.86% 98.70%         52  0.86%  schedule
        39  0.64% 99.34%         39  0.64%  softmax_cross_entropy
         8  0.13% 99.47%         30  0.49%  compute_metrics
         6 0.099% 99.57%         61  1.01%  cross_entropy_loss
         1 0.016% 99.59%       1321 21.79%  apply_gradients
         1 0.016% 99.60%       6062   100%  train_and_evaluate
         0     0% 99.60%       6062   100%  <unknown>
         0     0% 99.60%       6062   100%  __init__
         0     0% 99.60%       3872 63.87%  _call_wrapped_method
         0     0% 99.60%       6062   100%  _run_and_get_tests_result
         0     0% 99.60%       6062   100%  _run_code_in_main
         0     0% 99.60%       6062   100%  _run_in_app
         0     0% 99.60%       6062   100%  _run_main
         0     0% 99.60%       3872 63.87%  apply
         0     0% 99.60%        161  2.66%  apply_updates
         0     0% 99.60%       6062   100%  main
         0     0% 99.60%       6062   100%  main_function
         0     0% 99.60%       6062   100%  run
         0     0% 99.60%       6062   100%  runTests
         0     0% 99.60%       6062   100%  run_filename_as_main
         0     0% 99.60%       6062   100%  run_tests
         0     0% 99.60%       3872 63.87%  scope_fn
         0     0% 99.60%       6062   100%  test_train_and_evaluate
         0     0% 99.60%       1159 19.12%  update_fn
         0     0% 99.60%       3872 63.87%  wrapped_fn
         0     0% 99.60%       3872 63.87%  wrapped_module_method
         0     0% 99.60%       3872 63.87%  wrapper
```

I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```

[XLA:Python] Add helpers to Traceback and for working with pprof profiles.

* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.

PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Mike McCoy
2e5ab11652 Resolves issue 8744 2022-01-12 21:10:45 +00:00
Allen Lavoie
65ecea71fd
Remove trailing whitespace 2022-01-12 15:26:55 -05:00
Allen Lavoie
bd061e6e3c
Fix tf backprop through bfloat16 jax2tf 2022-01-12 14:54:41 -05:00
Allen Lavoie
ad85eab6ea
Add a jax2tf bfloat16 backprop test 2022-01-12 14:52:40 -05:00
jax authors
2e375f04a6 Merge pull request #8915 from mattjj:post-process-revisions
PiperOrigin-RevId: 421311438
2022-01-12 08:57:53 -08:00
Matthew Johnson
08aec823fd fix a custom_vjp post_process bug, related cleanups
related to #8783, doesn't completely fix it
2022-01-12 07:51:50 -08:00
Roy Frostig
ddc1c3e9bd enable custom transformation "stacking"
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:

```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...

@f.def_transpose
def f_t(y): ...

@f.defjvp
def f_jvp(x, tx): ...
```

In particular:

* Forward `def*` methods on custom transformations.

* Have unary `def*` methods return their argument so that, when used
  as decorators, they do not replace their target with `None`.

* Fix a bug in the use of `functools.update_wrapper`: previously a
  wrapper would overwrite its own attributes with those of the target
  callable (including its reference to the target callable).
2022-01-11 17:55:08 -08:00
jax authors
4dd1f001c6 Merge pull request #9162 from mattjj:simplify-memoize
PiperOrigin-RevId: 421154196
2022-01-11 16:34:20 -08:00
Yash Katariya
fbb8b9f8c6 Benchmarks for GDA. Also move create_global_mesh to test_utils since it was replicated in a lot of places.
PiperOrigin-RevId: 421142813
2022-01-11 15:43:05 -08:00
jax authors
f235edbf5c Merge pull request #9134 from froystig:custom-transpose
PiperOrigin-RevId: 421119056
2022-01-11 14:00:45 -08:00
Roy Frostig
1709e06800 introduce custom_transpose and a corresponding primitive
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00
jax authors
a30ec029ee Merge pull request #9164 from mattjj:checkify-tweaks
PiperOrigin-RevId: 421098252
2022-01-11 12:36:31 -08:00
Matthew Johnson
6850833c3a checkify: tweak some organization and names 2022-01-10 21:29:12 -08:00