12390 Commits

Author SHA1 Message Date
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
Skye Wanderman-Milne
d840f54fe5 Bump version numbers after 0.3.15 release
PiperOrigin-RevId: 463344160
2022-07-26 08:39:17 -07:00
jax authors
eb0f952ce1 Merge pull request #11613 from jakevdp:deadcode
PiperOrigin-RevId: 463312795
2022-07-26 05:34:17 -07:00
jax authors
c4b255b527 Merge pull request #11580 from jakevdp:fix-dynamic-index
PiperOrigin-RevId: 463311212
2022-07-26 05:28:47 -07:00
jax authors
4198abeb05 Merge pull request #11619 from gnecula:jax2tf_sharding
PiperOrigin-RevId: 463310860
2022-07-26 05:22:56 -07:00
jax authors
8a7b634b83 Merge pull request #11607 from gnecula:remove_loops
PiperOrigin-RevId: 463310793
2022-07-26 05:17:06 -07:00
George Necula
a4f312d9c3 [loops] Remove jax.experimental.loops.
Has been deprecated since April 2022.
See issue #10278 for an alternative API.
2022-07-26 13:55:57 +03:00
George Necula
afa8f5acb4 Remove jax.experimental.loops. See CHANGELOG
PiperOrigin-RevId: 463297399
2022-07-26 03:39:47 -07:00
George Necula
5b0b8ac58a [jax2tf] Improved documentation and tests for pjit 2022-07-26 12:50:32 +03:00
Yash Katariya
48edd212fc Add on_commit_callback to put the responsibility of renaming the directories on the users of the serialization library. This will also fix the GCS atomic rename issue where the users can write a success file when the commit is successful and check the existence of that file before deserialization.
PiperOrigin-RevId: 463238200
2022-07-25 20:12:55 -07:00
Lena Martens
48a2abcb72 Fix linear_jvp for multiple_results primitives with Zero tangents.
PiperOrigin-RevId: 463190431
2022-07-25 15:26:57 -07:00
Jake VanderPlas
98fac62897 remove dead code: jax._src.util.taggedtuple 2022-07-25 15:14:25 -07:00
jax authors
4fd700293e Merge pull request #11581 from jakevdp:fix-iscomplexobj
PiperOrigin-RevId: 463184295
2022-07-25 14:56:01 -07:00
jax authors
bb1544ddb1 Merge pull request #11611 from jakevdp:warning-filters
PiperOrigin-RevId: 463181057
2022-07-25 14:42:18 -07:00
jax authors
dcf28e71f4 Merge pull request #11610 from jakevdp:0315-changelog
PiperOrigin-RevId: 463175798
2022-07-25 14:19:52 -07:00
Yash Katariya
b42c84f26f Add a opsharding equality function until HLOSharding class is exported via pybind. The equality behavior is the same as HloSharding.
PiperOrigin-RevId: 463162918
2022-07-25 13:24:33 -07:00
Yash Katariya
ea1593a9b2 Make the _check_shapes_against_resources check general for all XLACompatibleShardings by looking at the opsharding proto of the shardings.
PiperOrigin-RevId: 463161459
2022-07-25 13:18:18 -07:00
jax authors
ec435c7e2b Merge pull request #11601 from gnecula:deprecate_mask1
PiperOrigin-RevId: 463138258
2022-07-25 11:39:13 -07:00
Jake VanderPlas
a40fb76a51 pytest: remove obsolete warning filters 2022-07-25 10:47:06 -07:00
Jake VanderPlas
bc90743603 Update changelog for jax/jaxlib v0.3.15 release 2022-07-25 09:47:44 -07:00
jax authors
45498ba4a1 Merge pull request #11591 from mbrukman:fix-jax2tf-readme
PiperOrigin-RevId: 463093055
2022-07-25 08:38:31 -07:00
George Necula
2fd46d13cd Delete the masking.py 2022-07-25 11:25:29 +03:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
jax authors
f5f650fc1c Merge pull request #11593 from sharadmv:debug-jvp
PiperOrigin-RevId: 462863615
2022-07-23 17:16:40 -07:00
jax authors
30d9ab24d7 Merge pull request #11590 from jakevdp:pillow-dep
PiperOrigin-RevId: 462722703
2022-07-22 15:59:24 -07:00
Sharad Vikram
fc1fa134c8 Adjust debug_callback JVP rule to only call on primals 2022-07-22 15:47:23 -07:00
jax authors
c26ae8fc8e Merge pull request #11592 from IvyZX:IvyZX-patch-1
PiperOrigin-RevId: 462714736
2022-07-22 15:18:18 -07:00
Ivy Zheng
dd2716911f
Merge branch 'google:main' into IvyZX-patch-1 2022-07-22 14:58:54 -07:00
Misha Brukman
f04ef8167d Improve text and code formatting in jax2tf docs
* add missing `python` code marker to get syntax highlighting
* fix code formatting by replacing double-backtick with single backtick for
  inline code formatting
* add missing close parenthesis in `tf.function(...)` sample code

Whitespace changes:

* add blank lines between text and code blocks for readability
* add blank lines to separate Python functions and `with` blocks from following
  code to improve code readability and clarify intent
* decrease indentation in code blocks to be flush-left for consistency
2022-07-22 17:40:38 -04:00
Jake VanderPlas
c4169a0c76 make tests compatible with recent pillow versions 2022-07-22 13:09:52 -07:00
jax authors
1a7c8831a8 Merge pull request #11589 from skye:workspace
PiperOrigin-RevId: 462669951
jax-v0.3.15 jaxlib-v0.3.15 jax-v0.3.15-rc
2022-07-22 11:46:22 -07:00
Skye Wanderman-Milne
26fbeb6e2a Update WORKSPACE and libtpu version for jaxlib 0.3.15, take 3 2022-07-22 11:41:39 -07:00
jax authors
e121e811ab Merge pull request #11536 from sharadmv:colab-debugger
PiperOrigin-RevId: 462665740
2022-07-22 11:28:02 -07:00
jax authors
0b6657e471 Merge pull request #11556 from RuffaloLavoisier:tYpO
PiperOrigin-RevId: 462648717
2022-07-22 10:13:10 -07:00
Sharad Vikram
4870710891 Enable debugging callbacks with pjit on TPU
PiperOrigin-RevId: 462527181
2022-07-21 20:22:14 -07:00
Jake VanderPlas
4a693400b9 BUG: make jnp.iscomplexobj compatible with jit 2022-07-21 16:56:29 -07:00
jax authors
8a67734e7b Merge pull request #11579 from sharadmv:fix-effects
PiperOrigin-RevId: 462478510
2022-07-21 15:02:46 -07:00
jax authors
7f0b9179f2 Merge pull request #11575 from gnecula:ds_progress
PiperOrigin-RevId: 462475336
2022-07-21 14:48:24 -07:00
jax authors
24134ec2a5 Merge pull request #11425 from pschuh:pjit-bugfix
PiperOrigin-RevId: 462469178
2022-07-21 14:20:00 -07:00
jax authors
540ee56ff2 Merge pull request #11576 from jakevdp:searchsorted-alt
PiperOrigin-RevId: 462461853
2022-07-21 13:47:43 -07:00
Jake VanderPlas
88b0d198ec dynamic_slice: correctly handle negative start indices in autodiff 2022-07-21 13:41:00 -07:00
Sharad Vikram
d6c172d53e Fix PE not allowing double JIT-ted effectful functions 2022-07-21 11:55:48 -07:00
jax authors
f6c168276b Merge pull request #11578 from jakevdp:wraps-mod
PiperOrigin-RevId: 462437654
2022-07-21 11:50:47 -07:00
Jake VanderPlas
9769a0accf DOC: ensure that _wraps() generates correct links to wrapped functions 2022-07-21 11:12:35 -07:00
jax authors
a4e754849e Merge pull request #11543 from nvcastet:fix_multigpu_test
PiperOrigin-RevId: 462418103
2022-07-21 10:27:57 -07:00
jax authors
1e05a1cfbc Merge pull request #10816 from mattjj:remove-old-pjit-comment
PiperOrigin-RevId: 462411602
2022-07-21 10:01:57 -07:00
George Necula
6c9d2a0b54 [jax2tf] Raise errors for experimental_native_lowering and custom_call
Raise explicit error when the experimental_native_lowering encounters
a mhlo.custom_call. This would lead to failure when trying to run in TF.
2022-07-21 19:58:05 +03:00
Jake VanderPlas
10411bfeae jnp.searchsorted: add optional method argument to control implementation 2022-07-21 09:40:18 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00