12482 Commits

Author SHA1 Message Date
Jake VanderPlas
91dbcbf525 Remove deprecated jax.experimental.stax
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00
jax authors
41eff98de3 Merge pull request #11708 from sharadmv:debugger
PiperOrigin-RevId: 464859269
2022-08-02 12:31:27 -07:00
Sharad Vikram
9a989573fc Fix debugger scope issue 2022-08-02 10:38:34 -07:00
jax authors
6079eccfa3 Merge pull request #11673 from jenkspt:add-dtype-args
PiperOrigin-RevId: 464809155
2022-08-02 09:20:27 -07:00
jax authors
bad9cd57f0 Merge pull request #11702 from jakevdp:remove-dedupe
PiperOrigin-RevId: 464795937
2022-08-02 08:19:59 -07:00
Yash Katariya
1dd35f831c Add a multihost test for Array on non-continuous mesh
PiperOrigin-RevId: 464659865
2022-08-01 17:14:10 -07:00
jax authors
01819257f6 Merge pull request #11701 from sharadmv:state
PiperOrigin-RevId: 464658336
2022-08-01 17:05:43 -07:00
Jake VanderPlas
3cce03554f [sparse] remove deprecated _dedupe() method 2022-08-01 16:59:25 -07:00
Parker Schuh
6b1610ce9e Add dep on protobuf and build protobufs if protoc is available.
PiperOrigin-RevId: 464645042
2022-08-01 16:01:02 -07:00
Penn
1987ca7389 Add dtype arg to jnp.concatenate and update tests 2022-08-01 15:48:40 -07:00
Sharad Vikram
8b7daa8095 Refactor state out of for_loop 2022-08-01 15:26:55 -07:00
jax authors
75d69725c3 Merge pull request #11640 from pschuh:pmap-shaped-array
PiperOrigin-RevId: 464623040
2022-08-01 14:22:41 -07:00
jax authors
37ea302f4b Merge pull request #11696 from jakevdp:autodoc-typehints
PiperOrigin-RevId: 464602087
2022-08-01 12:56:41 -07:00
Jake VanderPlas
37c249fc80 MAINT: unpin autodoc-typehints 2022-08-01 10:33:50 -07:00
jax authors
be8939771c Merge pull request #11675 from mattjj:new-remat-caching
PiperOrigin-RevId: 464543287
2022-08-01 08:38:28 -07:00
jax authors
9bd6fd66c2 Merge pull request #11691 from hawkinsp:xla
PiperOrigin-RevId: 464541736
2022-08-01 08:28:57 -07:00
Peter Hawkins
96cb47df32 Update XLA. 2022-08-01 12:10:27 +00:00
Sharad Vikram
c08b4ee6d9 Add jaxlib guards for debugging_primitives_test
PiperOrigin-RevId: 464453175
2022-07-31 21:23:22 -07:00
Yash Katariya
5112006903 Repr the op_sharding for printing.
PiperOrigin-RevId: 464304033
2022-07-30 15:25:26 -07:00
Sharad Vikram
11b206a18a Enable debugging primitives in pjit on CPU/GPU
PiperOrigin-RevId: 464208326
2022-07-29 20:10:27 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
Yash Katariya
2109c6ec8c Make is_compatible_aval an optional method which sharding interfaces can implement to raise a more meaningful error. Otherwise lower to opsharding and catch the error if it fails.
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
jax authors
6cffa720e7 Merge pull request #11670 from sharadmv:debugging-docs
PiperOrigin-RevId: 464145832
2022-07-29 13:26:56 -07:00
Sharad Vikram
decdca60c8 Change jaxdb->jdb and add option to force a backend 2022-07-29 12:51:27 -07:00
jax authors
21f632740c Merge pull request #11669 from LenaMartens:check-of-pjit
PiperOrigin-RevId: 464133301
2022-07-29 12:25:32 -07:00
Mehdi Amini
5a6cb438e8 Move MHLO to XLA
As part of the OpenXLA project, we're splitting XLA outside of TensorFlow.
MHLO belongs to OpenXLA and we're relocating it nested under XLA to allow the
split. Some further directory layout change will likely happen over time.

PiperOrigin-RevId: 464126676
2022-07-29 11:54:51 -07:00
Yash Katariya
2178f339bd Add OpShardingSharding and add a function that can calculate indices from an opsharding proto
PiperOrigin-RevId: 464123009
2022-07-29 11:37:39 -07:00
lenamartens
1ace5d351b Checkify: support checkify-of-pjit. 2022-07-29 19:25:22 +01:00
jax authors
cc19c94d36 Merge pull request #11665 from LenaMartens:docs
PiperOrigin-RevId: 464113785
2022-07-29 10:56:58 -07:00
lenamartens
0256301617 tweak checkify docs. 2022-07-29 18:15:36 +01:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
jax authors
ff97e12edf Merge pull request #11658 from froystig:tree-flatten-order
PiperOrigin-RevId: 464052722
2022-07-29 04:53:58 -07:00
jax authors
25538c8260 Merge pull request #11662 from sharadmv:debugging-docs
PiperOrigin-RevId: 464052649
2022-07-29 04:53:44 -07:00
jax authors
80eb641f6a Internal change
PiperOrigin-RevId: 464052046
2022-07-29 04:47:14 -07:00
Sharad Vikram
fb0cf668b8 Update debugging docs to mention pjit 2022-07-28 22:00:27 -07:00
jax authors
7b17ee87a9 Merge pull request #11660 from mattjj:custom-linear-solve-new-remat
PiperOrigin-RevId: 463998355
2022-07-28 21:57:55 -07:00
Matthew Johnson
aa043a60b6 add test for custom_linear_solve + new remat 2022-07-28 21:37:23 -07:00
jax authors
05a17cfb85 Merge pull request #11659 from mattjj:pmap-new-remat
PiperOrigin-RevId: 463995248
2022-07-28 21:27:59 -07:00
Matthew Johnson
e0c1e6c2ff add custom-policy partial eval and dce rules for pmap
Also add a failing test for xmap.
2022-07-28 21:13:25 -07:00
Yash Katariya
47623264db Export HloSharding via pybind which is a C++ wrapper around OpSharding proto.
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
jax authors
560c936a46 Merge pull request #11653 from sharadmv:debugging-docs
PiperOrigin-RevId: 463988525
2022-07-28 20:26:25 -07:00
Sharad Vikram
4386a0f909 Add debugging tools under jax.debug and documentation
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-07-28 20:07:26 -07:00
Roy Frostig
8677d99267 promise to flatten trees in left-to-right order 2022-07-28 19:28:20 -07:00
jax authors
1c9783ba5d Merge pull request #11657 from froystig:notebook-link
PiperOrigin-RevId: 463981188
2022-07-28 19:26:11 -07:00
Roy Frostig
cd77debfa7 fix broken link in "NNs with TFDS" notebook 2022-07-28 19:04:08 -07:00
jax authors
a636bd3468 Merge pull request #11656 from mattjj:while-loop-new-remat
PiperOrigin-RevId: 463973829
2022-07-28 18:34:07 -07:00
Matthew Johnson
7f3aa12142 add while_loop custom-policy partial eval rule 2022-07-28 18:04:49 -07:00
jax authors
22bc53580e Merge pull request #11651 from mattjj:cond-new-remat
PiperOrigin-RevId: 463948694
2022-07-28 16:05:45 -07:00
Matthew Johnson
ec9f9c3c07 add cond dce rule and custom-policy partial eval rule 2022-07-28 15:50:47 -07:00
Lena Martens
8ca5ecc7f3 Re-land #11498 after internal fixes.
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 463885774
2022-07-28 11:33:34 -07:00