12468 Commits

Author SHA1 Message Date
Jake VanderPlas
37c249fc80 MAINT: unpin autodoc-typehints 2022-08-01 10:33:50 -07:00
jax authors
be8939771c Merge pull request #11675 from mattjj:new-remat-caching
PiperOrigin-RevId: 464543287
2022-08-01 08:38:28 -07:00
jax authors
9bd6fd66c2 Merge pull request #11691 from hawkinsp:xla
PiperOrigin-RevId: 464541736
2022-08-01 08:28:57 -07:00
Peter Hawkins
96cb47df32 Update XLA. 2022-08-01 12:10:27 +00:00
Sharad Vikram
c08b4ee6d9 Add jaxlib guards for debugging_primitives_test
PiperOrigin-RevId: 464453175
2022-07-31 21:23:22 -07:00
Yash Katariya
5112006903 Repr the op_sharding for printing.
PiperOrigin-RevId: 464304033
2022-07-30 15:25:26 -07:00
Sharad Vikram
11b206a18a Enable debugging primitives in pjit on CPU/GPU
PiperOrigin-RevId: 464208326
2022-07-29 20:10:27 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
Yash Katariya
2109c6ec8c Make is_compatible_aval an optional method which sharding interfaces can implement to raise a more meaningful error. Otherwise lower to opsharding and catch the error if it fails.
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
jax authors
6cffa720e7 Merge pull request #11670 from sharadmv:debugging-docs
PiperOrigin-RevId: 464145832
2022-07-29 13:26:56 -07:00
Sharad Vikram
decdca60c8 Change jaxdb->jdb and add option to force a backend 2022-07-29 12:51:27 -07:00
jax authors
21f632740c Merge pull request #11669 from LenaMartens:check-of-pjit
PiperOrigin-RevId: 464133301
2022-07-29 12:25:32 -07:00
Mehdi Amini
5a6cb438e8 Move MHLO to XLA
As part of the OpenXLA project, we're splitting XLA outside of TensorFlow.
MHLO belongs to OpenXLA and we're relocating it nested under XLA to allow the
split. Some further directory layout change will likely happen over time.

PiperOrigin-RevId: 464126676
2022-07-29 11:54:51 -07:00
Yash Katariya
2178f339bd Add OpShardingSharding and add a function that can calculate indices from an opsharding proto
PiperOrigin-RevId: 464123009
2022-07-29 11:37:39 -07:00
lenamartens
1ace5d351b Checkify: support checkify-of-pjit. 2022-07-29 19:25:22 +01:00
jax authors
cc19c94d36 Merge pull request #11665 from LenaMartens:docs
PiperOrigin-RevId: 464113785
2022-07-29 10:56:58 -07:00
lenamartens
0256301617 tweak checkify docs. 2022-07-29 18:15:36 +01:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
jax authors
ff97e12edf Merge pull request #11658 from froystig:tree-flatten-order
PiperOrigin-RevId: 464052722
2022-07-29 04:53:58 -07:00
jax authors
25538c8260 Merge pull request #11662 from sharadmv:debugging-docs
PiperOrigin-RevId: 464052649
2022-07-29 04:53:44 -07:00
jax authors
80eb641f6a Internal change
PiperOrigin-RevId: 464052046
2022-07-29 04:47:14 -07:00
Sharad Vikram
fb0cf668b8 Update debugging docs to mention pjit 2022-07-28 22:00:27 -07:00
jax authors
7b17ee87a9 Merge pull request #11660 from mattjj:custom-linear-solve-new-remat
PiperOrigin-RevId: 463998355
2022-07-28 21:57:55 -07:00
Matthew Johnson
aa043a60b6 add test for custom_linear_solve + new remat 2022-07-28 21:37:23 -07:00
jax authors
05a17cfb85 Merge pull request #11659 from mattjj:pmap-new-remat
PiperOrigin-RevId: 463995248
2022-07-28 21:27:59 -07:00
Matthew Johnson
e0c1e6c2ff add custom-policy partial eval and dce rules for pmap
Also add a failing test for xmap.
2022-07-28 21:13:25 -07:00
Yash Katariya
47623264db Export HloSharding via pybind which is a C++ wrapper around OpSharding proto.
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
jax authors
560c936a46 Merge pull request #11653 from sharadmv:debugging-docs
PiperOrigin-RevId: 463988525
2022-07-28 20:26:25 -07:00
Sharad Vikram
4386a0f909 Add debugging tools under jax.debug and documentation
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-07-28 20:07:26 -07:00
Roy Frostig
8677d99267 promise to flatten trees in left-to-right order 2022-07-28 19:28:20 -07:00
jax authors
1c9783ba5d Merge pull request #11657 from froystig:notebook-link
PiperOrigin-RevId: 463981188
2022-07-28 19:26:11 -07:00
Roy Frostig
cd77debfa7 fix broken link in "NNs with TFDS" notebook 2022-07-28 19:04:08 -07:00
jax authors
a636bd3468 Merge pull request #11656 from mattjj:while-loop-new-remat
PiperOrigin-RevId: 463973829
2022-07-28 18:34:07 -07:00
Matthew Johnson
7f3aa12142 add while_loop custom-policy partial eval rule 2022-07-28 18:04:49 -07:00
jax authors
22bc53580e Merge pull request #11651 from mattjj:cond-new-remat
PiperOrigin-RevId: 463948694
2022-07-28 16:05:45 -07:00
Matthew Johnson
ec9f9c3c07 add cond dce rule and custom-policy partial eval rule 2022-07-28 15:50:47 -07:00
Lena Martens
8ca5ecc7f3 Re-land #11498 after internal fixes.
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 463885774
2022-07-28 11:33:34 -07:00
jax authors
697c9b1736 Merge pull request #11641 from mattjj:cond-partial-eval-effects
PiperOrigin-RevId: 463870869
2022-07-28 10:39:48 -07:00
jax authors
97a9b12790 Turn on coordination service by default for all JAX users.
Coordination service is the new implementation of JAX's distributed service. The API remains the same, and eventually will be expanded for newer features such as error reporting.

PiperOrigin-RevId: 463870794
2022-07-28 10:34:07 -07:00
Matthew Johnson
f56ce8a01c update cond partial eval to so eqn effects match branch jaxprs'
Also add some new tests, including some skipped ones, for how effects should
interact with jax.linearize (I think...).
2022-07-28 10:01:56 -07:00
Peter Hawkins
9e6254e058 Increase shard counts for TPU tests in an attempt to fix CI timeouts under asan.
PiperOrigin-RevId: 463830139
2022-07-28 07:14:36 -07:00
Andreas Steiner
0fa3f0346c Makes flax.optim import future proof.
PiperOrigin-RevId: 463820883
2022-07-28 06:15:08 -07:00
jax authors
6eb1cef91f Merge pull request #11635 from pschuh:function-cache
PiperOrigin-RevId: 463761819
2022-07-27 23:20:12 -07:00
Yash Katariya
f4637c364d Fix the gda_xla_sharding_match benchmark which was regressing. This was happening because that function was executed from top to bottom a couple of times and each time a new mesh object was created violating the already created cache which doesn't happen in real life.
```
gda_xla_sharding_match_(256, 8)_PartitionSpec('x', 'y')     21.8ms ± 2%              1.3ms ± 2%  -93.80%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(None,)        21.8ms ± 4%              1.3ms ± 1%  -93.92%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('x',)         21.8ms ± 3%              1.3ms ± 1%  -94.11%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('y',)         21.8ms ± 3%              1.3ms ± 0%  -94.12%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(('x', 'y'),)  21.8ms ± 3%              1.3ms ± 1%  -94.07%          (p=0.008 n=5+5)
gda_xla_sharding_match_(128, 8)_PartitionSpec('x', 'y')     13.9ms ± 6%              1.3ms ± 1%  -90.85%          (p=0.008 n=5+5)
gda_xla_sharding_match_(4, 2)_PartitionSpec('x', 'y')       5.72ms ±10%             1.25ms ± 1%  -78.15%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec('x', 'y')      6.17ms ±11%             1.25ms ± 1%  -79.71%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec(('x', 'y'),)   6.17ms ±10%             1.26ms ± 2%  -79.61%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 463760534
2022-07-27 23:08:55 -07:00
jax authors
2592dfeae7 Merge pull request #11639 from sharadmv:webpdb
PiperOrigin-RevId: 463752326
2022-07-27 21:58:07 -07:00
Xin Zhou
889187f9c6 [mhlo] Add result type inference for mhlo.clamp.
PiperOrigin-RevId: 463743558
2022-07-27 20:39:43 -07:00
Yash Katariya
0dbf492cec Log the key used for barrier
PiperOrigin-RevId: 463741926
2022-07-27 20:25:40 -07:00
Sharad Vikram
547d021157 Enable compatibility with older versions of web_pdb 2022-07-27 18:21:44 -07:00
jax authors
27655af6b9 Merge pull request #11634 from mattjj:fastpath-for-shaped-abstractify
PiperOrigin-RevId: 463718000
2022-07-27 17:33:58 -07:00
jax authors
d5fdd9e266 Merge pull request #11636 from sharadmv:webpdb
PiperOrigin-RevId: 463692257
2022-07-27 15:21:35 -07:00