Jake VanderPlas
91dbcbf525
Remove deprecated jax.experimental.stax
...
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00
jax authors
41eff98de3
Merge pull request #11708 from sharadmv:debugger
...
PiperOrigin-RevId: 464859269
2022-08-02 12:31:27 -07:00
Sharad Vikram
9a989573fc
Fix debugger scope issue
2022-08-02 10:38:34 -07:00
jax authors
6079eccfa3
Merge pull request #11673 from jenkspt:add-dtype-args
...
PiperOrigin-RevId: 464809155
2022-08-02 09:20:27 -07:00
jax authors
bad9cd57f0
Merge pull request #11702 from jakevdp:remove-dedupe
...
PiperOrigin-RevId: 464795937
2022-08-02 08:19:59 -07:00
Yash Katariya
1dd35f831c
Add a multihost test for Array on non-continuous mesh
...
PiperOrigin-RevId: 464659865
2022-08-01 17:14:10 -07:00
jax authors
01819257f6
Merge pull request #11701 from sharadmv:state
...
PiperOrigin-RevId: 464658336
2022-08-01 17:05:43 -07:00
Jake VanderPlas
3cce03554f
[sparse] remove deprecated _dedupe() method
2022-08-01 16:59:25 -07:00
Parker Schuh
6b1610ce9e
Add dep on protobuf and build protobufs if protoc is available.
...
PiperOrigin-RevId: 464645042
2022-08-01 16:01:02 -07:00
Penn
1987ca7389
Add dtype arg to jnp.concatenate and update tests
2022-08-01 15:48:40 -07:00
Sharad Vikram
8b7daa8095
Refactor state out of for_loop
2022-08-01 15:26:55 -07:00
jax authors
75d69725c3
Merge pull request #11640 from pschuh:pmap-shaped-array
...
PiperOrigin-RevId: 464623040
2022-08-01 14:22:41 -07:00
jax authors
37ea302f4b
Merge pull request #11696 from jakevdp:autodoc-typehints
...
PiperOrigin-RevId: 464602087
2022-08-01 12:56:41 -07:00
Jake VanderPlas
37c249fc80
MAINT: unpin autodoc-typehints
2022-08-01 10:33:50 -07:00
jax authors
be8939771c
Merge pull request #11675 from mattjj:new-remat-caching
...
PiperOrigin-RevId: 464543287
2022-08-01 08:38:28 -07:00
jax authors
9bd6fd66c2
Merge pull request #11691 from hawkinsp:xla
...
PiperOrigin-RevId: 464541736
2022-08-01 08:28:57 -07:00
Peter Hawkins
96cb47df32
Update XLA.
2022-08-01 12:10:27 +00:00
Sharad Vikram
c08b4ee6d9
Add jaxlib guards for debugging_primitives_test
...
PiperOrigin-RevId: 464453175
2022-07-31 21:23:22 -07:00
Yash Katariya
5112006903
Repr the op_sharding for printing.
...
PiperOrigin-RevId: 464304033
2022-07-30 15:25:26 -07:00
Sharad Vikram
11b206a18a
Enable debugging primitives in pjit
on CPU/GPU
...
PiperOrigin-RevId: 464208326
2022-07-29 20:10:27 -07:00
Matthew Johnson
cbcfe95e80
fix ad_checkpoint.checkpoint caching issue
...
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
Yash Katariya
2109c6ec8c
Make is_compatible_aval
an optional method which sharding interfaces can implement to raise a more meaningful error. Otherwise lower to opsharding and catch the error if it fails.
...
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
jax authors
6cffa720e7
Merge pull request #11670 from sharadmv:debugging-docs
...
PiperOrigin-RevId: 464145832
2022-07-29 13:26:56 -07:00
Sharad Vikram
decdca60c8
Change jaxdb->jdb and add option to force a backend
2022-07-29 12:51:27 -07:00
jax authors
21f632740c
Merge pull request #11669 from LenaMartens:check-of-pjit
...
PiperOrigin-RevId: 464133301
2022-07-29 12:25:32 -07:00
Mehdi Amini
5a6cb438e8
Move MHLO to XLA
...
As part of the OpenXLA project, we're splitting XLA outside of TensorFlow.
MHLO belongs to OpenXLA and we're relocating it nested under XLA to allow the
split. Some further directory layout change will likely happen over time.
PiperOrigin-RevId: 464126676
2022-07-29 11:54:51 -07:00
Yash Katariya
2178f339bd
Add OpShardingSharding
and add a function that can calculate indices from an opsharding proto
...
PiperOrigin-RevId: 464123009
2022-07-29 11:37:39 -07:00
lenamartens
1ace5d351b
Checkify: support checkify-of-pjit.
2022-07-29 19:25:22 +01:00
jax authors
cc19c94d36
Merge pull request #11665 from LenaMartens:docs
...
PiperOrigin-RevId: 464113785
2022-07-29 10:56:58 -07:00
lenamartens
0256301617
tweak checkify docs.
2022-07-29 18:15:36 +01:00
Yash Katariya
9a5af235da
Delete sharded_jit
...
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
jax authors
ff97e12edf
Merge pull request #11658 from froystig:tree-flatten-order
...
PiperOrigin-RevId: 464052722
2022-07-29 04:53:58 -07:00
jax authors
25538c8260
Merge pull request #11662 from sharadmv:debugging-docs
...
PiperOrigin-RevId: 464052649
2022-07-29 04:53:44 -07:00
jax authors
80eb641f6a
Internal change
...
PiperOrigin-RevId: 464052046
2022-07-29 04:47:14 -07:00
Sharad Vikram
fb0cf668b8
Update debugging docs to mention pjit
2022-07-28 22:00:27 -07:00
jax authors
7b17ee87a9
Merge pull request #11660 from mattjj:custom-linear-solve-new-remat
...
PiperOrigin-RevId: 463998355
2022-07-28 21:57:55 -07:00
Matthew Johnson
aa043a60b6
add test for custom_linear_solve + new remat
2022-07-28 21:37:23 -07:00
jax authors
05a17cfb85
Merge pull request #11659 from mattjj:pmap-new-remat
...
PiperOrigin-RevId: 463995248
2022-07-28 21:27:59 -07:00
Matthew Johnson
e0c1e6c2ff
add custom-policy partial eval and dce rules for pmap
...
Also add a failing test for xmap.
2022-07-28 21:13:25 -07:00
Yash Katariya
47623264db
Export HloSharding
via pybind which is a C++ wrapper around OpSharding proto.
...
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
jax authors
560c936a46
Merge pull request #11653 from sharadmv:debugging-docs
...
PiperOrigin-RevId: 463988525
2022-07-28 20:26:25 -07:00
Sharad Vikram
4386a0f909
Add debugging tools under jax.debug
and documentation
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-07-28 20:07:26 -07:00
Roy Frostig
8677d99267
promise to flatten trees in left-to-right order
2022-07-28 19:28:20 -07:00
jax authors
1c9783ba5d
Merge pull request #11657 from froystig:notebook-link
...
PiperOrigin-RevId: 463981188
2022-07-28 19:26:11 -07:00
Roy Frostig
cd77debfa7
fix broken link in "NNs with TFDS" notebook
2022-07-28 19:04:08 -07:00
jax authors
a636bd3468
Merge pull request #11656 from mattjj:while-loop-new-remat
...
PiperOrigin-RevId: 463973829
2022-07-28 18:34:07 -07:00
Matthew Johnson
7f3aa12142
add while_loop custom-policy partial eval rule
2022-07-28 18:04:49 -07:00
jax authors
22bc53580e
Merge pull request #11651 from mattjj:cond-new-remat
...
PiperOrigin-RevId: 463948694
2022-07-28 16:05:45 -07:00
Matthew Johnson
ec9f9c3c07
add cond dce rule and custom-policy partial eval rule
2022-07-28 15:50:47 -07:00
Lena Martens
8ca5ecc7f3
Re-land #11498 after internal fixes.
...
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module
PiperOrigin-RevId: 463885774
2022-07-28 11:33:34 -07:00