10193 Commits

Author SHA1 Message Date
Lena Martens
8ea85769ea Checkify: add way to disable categories of errors.
By default only user_asserts are lifted into the checked function.
2022-01-18 17:59:50 +00:00
jax authors
6411f8a033 Merge pull request #9184 from jakevdp:unique-nan
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
jax authors
bebe9845a8 Merge pull request #9205 from jakevdp:einsum-tuple
PiperOrigin-RevId: 422013671
2022-01-15 02:46:56 -08:00
Yash Katariya
b92db58eaf Canonicalize parsed partition spec before passing to lower_mesh_computation. Creates a new data structure CanonicalizedParsedPartitionSpec which strips empty tuples from the end of parsed partitions to canonicalize the specs so that P(None) and None for example in in_axis_resources are equivalent.
I have been bit by this 3 times and its about time I fix this. This also fixes a bug where fully replicated values are allowed with non-contiguous meshes (in this case P(None) and None) were not equal.

PiperOrigin-RevId: 421918164
2022-01-14 14:52:43 -08:00
Peter Hawkins
b509aae2a2 Split lax_control_flow_test into three separate tests.
Split the custom root and custom linear solve tests into separate test files.

Disable two slow custom linear solve tests.
Add a few jit decorators to slow tests in lax_control_flow_test.

PiperOrigin-RevId: 421901487
2022-01-14 13:36:34 -08:00
jax authors
c9169fa0d5 Merge pull request #9189 from gnecula:tf_reduce_window
PiperOrigin-RevId: 421875035
2022-01-14 11:35:16 -08:00
Jake VanderPlas
77d60cf4dd einsum: clarify use of precision. 2022-01-14 11:08:13 -08:00
jax authors
7f07f1b86b Merge pull request #9200 from google:LenaMartens-patch-1
PiperOrigin-RevId: 421845558
2022-01-14 09:30:24 -08:00
Lena Martens
f591d0b2e9
Add ensure_compile_time_eval docstring to docs 2022-01-14 11:18:40 +00:00
Jake VanderPlas
bd157cf056 jnp.unique: properly handle NaN values 2022-01-13 15:54:07 -08:00
jax authors
3b374e7dd9 Merge pull request #9127 from jakevdp:searchsorted-complex
PiperOrigin-RevId: 421672606
2022-01-13 15:25:41 -08:00
Jake VanderPlas
8ca10ea53f searchsorted: use correct ordering for complex inputs 2022-01-13 13:45:59 -08:00
jax authors
cd73a4195f Merge pull request #9178 from jakevdp:sort-corner-cases
PiperOrigin-RevId: 421646783
2022-01-13 13:33:02 -08:00
Jake VanderPlas
d8bdd9a19d lax.sort: regularize handling of -0.0 and -NaN 2022-01-13 13:03:41 -08:00
Yash Katariya
0532a63261 Optimizations for GDA to make creating GDA faster.
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function and not `_HashableIndex` which is a class which does not exist anymore) is 1.5 - 2 times slower than using math. markdaoust@ helped with the math here (going to office has its own perks :) )

* Get rid of `_HashableIndex` class and replace it with a function `_hashed_index`. Dataclass is extremely slow.

* Only calculate global_mesh.local_devices once. Even though its a cached property (but its after python 3.8)

```
name                                           old time/op             new time/op             delta
gda_construction_callback_(4, 2)_['x', 'y']    4.77ms ± 5%             4.74ms ± 5%     ~           (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y']       17.9ms ± 5%              9.0ms ± 2%  -49.92%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y']    11.4ms ± 2%              2.9ms ± 2%  -74.52%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None]        34.0ms ±20%             30.5ms ± 2%     ~             (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None]           15.9ms ± 2%              7.7ms ± 3%  -51.56%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None]        9.39ms ± 3%             1.74ms ± 2%  -81.44%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x']         8.87ms ± 2%             8.92ms ± 3%     ~             (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x']            16.4ms ± 2%              7.7ms ± 1%  -52.66%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x']         9.85ms ± 1%             1.90ms ± 2%  -80.68%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y']         15.9ms ± 3%             16.0ms ± 5%     ~             (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y']            15.8ms ± 3%              7.6ms ± 1%  -52.04%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y']         9.29ms ± 1%             1.78ms ± 1%  -80.79%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')]  4.65ms ± 2%             4.62ms ± 3%     ~            (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')]     18.6ms ± 3%              9.7ms ± 5%  -47.76%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')]  11.8ms ± 4%              3.5ms ± 2%  -70.28%          (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y']       8.54ms ± 1%             4.03ms ± 2%  -52.84%          (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y']    5.40ms ± 4%             1.10ms ± 1%  -79.69%          (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y']          173µs ± 1%              193µs ± 3%  +11.63%          (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y']       127µs ± 1%              147µs ± 1%  +15.57%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00
Yuanzhong Xu
2e4687a62e PartitionSpec: allow partially specified sharding
PiperOrigin-RevId: 421603194
2022-01-13 10:35:15 -08:00
George Necula
5bfe1852a4 [jax2tf] Add jax2tf_associative_scan_reductions flag
This flag allows users to match the JAX performance for
associative reductions in CPU.
See README.md for details.
2022-01-13 15:52:18 +02:00
jax authors
f0e4f0472d Merge pull request #9186 from froystig:get-default-rng
PiperOrigin-RevId: 421454645
2022-01-12 19:48:04 -08:00
Roy Frostig
026b91b85d add random.default_prng_impl to retrieve the default PRNG implementation 2022-01-12 19:13:14 -08:00
jax authors
436ce7904c Merge pull request #9175 from froystig:custom-xform-wrappers-forward-attrs
PiperOrigin-RevId: 421449851
2022-01-12 19:11:22 -08:00
jax authors
f08bb50bfa Merge pull request #8869 from mbmccoy:issue8744
PiperOrigin-RevId: 421404424
2022-01-12 15:07:47 -08:00
Peter Hawkins
dbf03d2ee6 [JAX] Add support for generating "equation profiles" in JAX.
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result.

For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
 primitive: Total 6062.0
            1509.0 (24.89%): mul
             936.0 (15.44%): add
             589.0 ( 9.72%): reshape
             492.0 ( 8.12%): div
             485.0 ( 8.00%): broadcast_in_dim
             330.0 ( 5.44%): reduce_sum
             322.0 ( 5.31%): integer_pow
             230.0 ( 3.79%): add_any
             174.0 ( 2.87%): convert_element_type
             160.0 ( 2.64%): select
             158.0 ( 2.61%): conv_general_dilated
             116.0 ( 1.91%): sub
             110.0 ( 1.81%): eq
             110.0 ( 1.81%): neg
             104.0 ( 1.72%): max
              53.0 ( 0.87%): rsqrt
              52.0 ( 0.86%): rev
              49.0 ( 0.81%): custom_jvp_call_jaxpr
              49.0 ( 0.81%): gt
               5.0 (0.082%): xla_call
               4.0 (0.066%): min
               3.0 (0.049%): dot_general
               3.0 (0.049%): lt
               2.0 (0.033%): cos
               2.0 (0.033%): exp
               2.0 (0.033%): iota
               2.0 (0.033%): log
               2.0 (0.033%): psum
               2.0 (0.033%): reduce_max
               2.0 (0.033%): stop_gradient
               1.0 (0.016%): argmax
               1.0 (0.016%): reduce_window_max
               1.0 (0.016%): select_and_scatter_add
               1.0 (0.016%): transpose
               1.0 (0.016%): xla_pmap
```

Or the lines of code that generated the most equations:
```
$ pprof  --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
      flat  flat%   sum%        cum   cum%
      1537 25.35% 25.35%       1537 25.35%  _compute_stats
      1484 24.48% 49.84%       1484 24.48%  _normalize
       849 14.01% 63.84%       6062   100%  __call__
       644 10.62% 74.46%        644 10.62%  <unknown>
       483  7.97% 82.43%        483  7.97%  <unknown>
       392  6.47% 88.90%       6061   100%  train_step
       324  5.34% 94.24%        324  5.34%  <unknown>
       161  2.66% 96.90%        161  2.66%  <unknown>
        57  0.94% 97.84%       4292 70.80%  loss_fn
        52  0.86% 98.70%         52  0.86%  schedule
        39  0.64% 99.34%         39  0.64%  softmax_cross_entropy
         8  0.13% 99.47%         30  0.49%  compute_metrics
         6 0.099% 99.57%         61  1.01%  cross_entropy_loss
         1 0.016% 99.59%       1321 21.79%  apply_gradients
         1 0.016% 99.60%       6062   100%  train_and_evaluate
         0     0% 99.60%       6062   100%  <unknown>
         0     0% 99.60%       6062   100%  __init__
         0     0% 99.60%       3872 63.87%  _call_wrapped_method
         0     0% 99.60%       6062   100%  _run_and_get_tests_result
         0     0% 99.60%       6062   100%  _run_code_in_main
         0     0% 99.60%       6062   100%  _run_in_app
         0     0% 99.60%       6062   100%  _run_main
         0     0% 99.60%       3872 63.87%  apply
         0     0% 99.60%        161  2.66%  apply_updates
         0     0% 99.60%       6062   100%  main
         0     0% 99.60%       6062   100%  main_function
         0     0% 99.60%       6062   100%  run
         0     0% 99.60%       6062   100%  runTests
         0     0% 99.60%       6062   100%  run_filename_as_main
         0     0% 99.60%       6062   100%  run_tests
         0     0% 99.60%       3872 63.87%  scope_fn
         0     0% 99.60%       6062   100%  test_train_and_evaluate
         0     0% 99.60%       1159 19.12%  update_fn
         0     0% 99.60%       3872 63.87%  wrapped_fn
         0     0% 99.60%       3872 63.87%  wrapped_module_method
         0     0% 99.60%       3872 63.87%  wrapper
```

I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```

[XLA:Python] Add helpers to Traceback and for working with pprof profiles.

* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.

PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Mike McCoy
2e5ab11652 Resolves issue 8744 2022-01-12 21:10:45 +00:00
jax authors
2e375f04a6 Merge pull request #8915 from mattjj:post-process-revisions
PiperOrigin-RevId: 421311438
2022-01-12 08:57:53 -08:00
Matthew Johnson
08aec823fd fix a custom_vjp post_process bug, related cleanups
related to #8783, doesn't completely fix it
2022-01-12 07:51:50 -08:00
Roy Frostig
ddc1c3e9bd enable custom transformation "stacking"
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:

```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...

@f.def_transpose
def f_t(y): ...

@f.defjvp
def f_jvp(x, tx): ...
```

In particular:

* Forward `def*` methods on custom transformations.

* Have unary `def*` methods return their argument so that, when used
  as decorators, they do not replace their target with `None`.

* Fix a bug in the use of `functools.update_wrapper`: previously a
  wrapper would overwrite its own attributes with those of the target
  callable (including its reference to the target callable).
2022-01-11 17:55:08 -08:00
jax authors
4dd1f001c6 Merge pull request #9162 from mattjj:simplify-memoize
PiperOrigin-RevId: 421154196
2022-01-11 16:34:20 -08:00
Yash Katariya
fbb8b9f8c6 Benchmarks for GDA. Also move create_global_mesh to test_utils since it was replicated in a lot of places.
PiperOrigin-RevId: 421142813
2022-01-11 15:43:05 -08:00
jax authors
f235edbf5c Merge pull request #9134 from froystig:custom-transpose
PiperOrigin-RevId: 421119056
2022-01-11 14:00:45 -08:00
Roy Frostig
1709e06800 introduce custom_transpose and a corresponding primitive
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00
jax authors
a30ec029ee Merge pull request #9164 from mattjj:checkify-tweaks
PiperOrigin-RevId: 421098252
2022-01-11 12:36:31 -08:00
Matthew Johnson
6850833c3a checkify: tweak some organization and names 2022-01-10 21:29:12 -08:00
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
Yash Katariya
7bc51879d4 Add tests for 0d fully replicated scalar input to pjit.
PiperOrigin-RevId: 420884601
2022-01-10 16:24:31 -08:00
jax authors
67723da38b Merge pull request #9143 from mattjj:fix-jaxpr-checking-error-messages
PiperOrigin-RevId: 420866862
2022-01-10 15:09:04 -08:00
Peter Hawkins
9bc6d1103e [JAX] Fix spurious inequality for two apparently equal PyTreeDefs.
When constructed via one path we were filling in the .custom field of nodes that weren't custom types.

Fixes https://github.com/google/jax/issues/9066

PiperOrigin-RevId: 420858917
2022-01-10 14:35:56 -08:00
Matthew Johnson
e321964245 de-duplicate util.memoize and util.cache
The only difference between the two was that
jax.config.jax_check_tracer_leaks disables the caching under util.cache
but not under util.memoize.

We could add that as an option on the same function if it turns out to
be important, but it seems unnecessary. Moreover there are only two
callers (in dtypes.py and in batching.py).

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-01-10 14:28:28 -08:00
Yash Katariya
e5e343c13e Speed up deserialization by not doing the shard indices calculation twice. This calculation is expensive because it happens on the global mesh size scale.
PiperOrigin-RevId: 420855444
2022-01-10 14:26:11 -08:00
Yash Katariya
93b4554101 Do shard indices calculations only once in callbacks by using GDA fast path because indices calculation is expensive as it happens on the global mesh (if the global mesh is large). The second calculation happens in _create_local_shards which is avoided if GDA fast path is enabled.
PiperOrigin-RevId: 420855232
2022-01-10 14:21:54 -08:00
Peter Hawkins
f7de16c818 Disable a pjit test that is failing on GPU.
PiperOrigin-RevId: 420854561
2022-01-10 14:17:17 -08:00
jax authors
a41d3e2892 Merge pull request #9123 from prazek:patch-1
PiperOrigin-RevId: 420783231
2022-01-10 09:51:23 -08:00
Peter Hawkins
0e940b66dd Fix spurious "donated buffers were not usable" warning when using MLIR.
PiperOrigin-RevId: 420782796
2022-01-10 09:47:00 -08:00
jax authors
977e142d55 Merge pull request #9154 from che-shr-cat:patch-1
PiperOrigin-RevId: 420782381
2022-01-10 09:42:36 -08:00
jax authors
8016026a06 Merge pull request #9128 from jakevdp:bcoo-metadata
PiperOrigin-RevId: 420782354
2022-01-10 09:37:48 -08:00
Peter Hawkins
5801079a4b Enable JAX->MLIR lowering by default.
Before this change, JAX produces HLO using the XLA:Python builder APIs. After this change JAX produces MHLO using MLIR:Python APIs, and converts the MHLO to HLO for compilation with XLA. This is a lateral shift that should have little immediate impact, but unlocks a number of interesting opportunities in the future (e.g., mixing MLIR dialects within a JAX program).

[XLA:Python] Pass MLIR input as a std::string to work around https://github.com/pybind/pybind11/issues/2765. A better fix would be to update pybind11 but that is hitting Windows-related hurdles; for now, just avoid relying on reference lifetime extension.

Brax: update test seeds to avoid test failures. Additional constant folding (canonicalization) in the MHLO lowering path seems to cause small numerical differences.
PiperOrigin-RevId: 420755696
2022-01-10 07:26:23 -08:00
jax authors
cdeced943b Merge pull request #9113 from che-shr-cat:main
PiperOrigin-RevId: 420747734
2022-01-10 06:41:31 -08:00
che-shr-cat
d2c6c06546
Fix DeviceArray class reference 2022-01-10 17:34:09 +03:00
che-shr-cat
78977d6f5a fix broken links and update texts in thinking_in_jax.ipynb 2022-01-10 16:19:57 +03:00
jax authors
ac5f1c4b24 Merge pull request #9126 from LenaMartens:changelist/418004472
PiperOrigin-RevId: 420716471
2022-01-10 03:25:27 -08:00
Matthew Johnson
3548e023ec fix jaxpr type checking error messages
The pretty-printing changes a few months ago defined variable names
based on the state in JaxprPpContext instances. But that meant incorrect
variable names could be printed in jaxpr type checking error messages.

This commit correctly threads through the context so as to provide
error messages with coherent variable names.
2022-01-09 20:07:58 -08:00