Matthew Johnson
148173630f
add an optional fastpath for api_util.shaped_abstractify
...
also add a benchmark for it, 8.7ms -> 0.2ms on my machine
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
Yash Katariya
48edd212fc
Add on_commit_callback
to put the responsibility of renaming the directories on the users of the serialization library. This will also fix the GCS atomic rename issue where the users can write a success file when the commit is successful and check the existence of that file before deserialization.
...
PiperOrigin-RevId: 463238200
2022-07-25 20:12:55 -07:00
Lena Martens
48a2abcb72
Fix linear_jvp for multiple_results primitives with Zero tangents.
...
PiperOrigin-RevId: 463190431
2022-07-25 15:26:57 -07:00
jax authors
4fd700293e
Merge pull request #11581 from jakevdp:fix-iscomplexobj
...
PiperOrigin-RevId: 463184295
2022-07-25 14:56:01 -07:00
jax authors
bb1544ddb1
Merge pull request #11611 from jakevdp:warning-filters
...
PiperOrigin-RevId: 463181057
2022-07-25 14:42:18 -07:00
jax authors
dcf28e71f4
Merge pull request #11610 from jakevdp:0315-changelog
...
PiperOrigin-RevId: 463175798
2022-07-25 14:19:52 -07:00
Yash Katariya
b42c84f26f
Add a opsharding equality function until HLOSharding class is exported via pybind. The equality behavior is the same as HloSharding.
...
PiperOrigin-RevId: 463162918
2022-07-25 13:24:33 -07:00
Yash Katariya
ea1593a9b2
Make the _check_shapes_against_resources
check general for all XLACompatibleSharding
s by looking at the opsharding proto of the shardings.
...
PiperOrigin-RevId: 463161459
2022-07-25 13:18:18 -07:00
jax authors
ec435c7e2b
Merge pull request #11601 from gnecula:deprecate_mask1
...
PiperOrigin-RevId: 463138258
2022-07-25 11:39:13 -07:00
Jake VanderPlas
a40fb76a51
pytest: remove obsolete warning filters
2022-07-25 10:47:06 -07:00
Jake VanderPlas
bc90743603
Update changelog for jax/jaxlib v0.3.15 release
2022-07-25 09:47:44 -07:00
jax authors
45498ba4a1
Merge pull request #11591 from mbrukman:fix-jax2tf-readme
...
PiperOrigin-RevId: 463093055
2022-07-25 08:38:31 -07:00
George Necula
2fd46d13cd
Delete the masking.py
2022-07-25 11:25:29 +03:00
George Necula
ab7d036271
Remove dependencies on masking.py
2022-07-25 11:25:26 +03:00
George Necula
66dc95e2de
removes the jax.mask and jax.shapecheck APIs.
...
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
jax authors
f5f650fc1c
Merge pull request #11593 from sharadmv:debug-jvp
...
PiperOrigin-RevId: 462863615
2022-07-23 17:16:40 -07:00
jax authors
30d9ab24d7
Merge pull request #11590 from jakevdp:pillow-dep
...
PiperOrigin-RevId: 462722703
2022-07-22 15:59:24 -07:00
Sharad Vikram
fc1fa134c8
Adjust debug_callback JVP rule to only call on primals
2022-07-22 15:47:23 -07:00
jax authors
c26ae8fc8e
Merge pull request #11592 from IvyZX:IvyZX-patch-1
...
PiperOrigin-RevId: 462714736
2022-07-22 15:18:18 -07:00
Ivy Zheng
dd2716911f
Merge branch 'google:main' into IvyZX-patch-1
2022-07-22 14:58:54 -07:00
Misha Brukman
f04ef8167d
Improve text and code formatting in jax2tf docs
...
* add missing `python` code marker to get syntax highlighting
* fix code formatting by replacing double-backtick with single backtick for
inline code formatting
* add missing close parenthesis in `tf.function(...)` sample code
Whitespace changes:
* add blank lines between text and code blocks for readability
* add blank lines to separate Python functions and `with` blocks from following
code to improve code readability and clarify intent
* decrease indentation in code blocks to be flush-left for consistency
2022-07-22 17:40:38 -04:00
Jake VanderPlas
c4169a0c76
make tests compatible with recent pillow versions
2022-07-22 13:09:52 -07:00
jax authors
1a7c8831a8
Merge pull request #11589 from skye:workspace
...
PiperOrigin-RevId: 462669951
jax-v0.3.15
jaxlib-v0.3.15
jax-v0.3.15-rc
2022-07-22 11:46:22 -07:00
Skye Wanderman-Milne
26fbeb6e2a
Update WORKSPACE and libtpu version for jaxlib 0.3.15, take 3
2022-07-22 11:41:39 -07:00
jax authors
e121e811ab
Merge pull request #11536 from sharadmv:colab-debugger
...
PiperOrigin-RevId: 462665740
2022-07-22 11:28:02 -07:00
jax authors
0b6657e471
Merge pull request #11556 from RuffaloLavoisier:tYpO
...
PiperOrigin-RevId: 462648717
2022-07-22 10:13:10 -07:00
Sharad Vikram
4870710891
Enable debugging callbacks with pjit on TPU
...
PiperOrigin-RevId: 462527181
2022-07-21 20:22:14 -07:00
Jake VanderPlas
4a693400b9
BUG: make jnp.iscomplexobj compatible with jit
2022-07-21 16:56:29 -07:00
jax authors
8a67734e7b
Merge pull request #11579 from sharadmv:fix-effects
...
PiperOrigin-RevId: 462478510
2022-07-21 15:02:46 -07:00
jax authors
7f0b9179f2
Merge pull request #11575 from gnecula:ds_progress
...
PiperOrigin-RevId: 462475336
2022-07-21 14:48:24 -07:00
jax authors
24134ec2a5
Merge pull request #11425 from pschuh:pjit-bugfix
...
PiperOrigin-RevId: 462469178
2022-07-21 14:20:00 -07:00
jax authors
540ee56ff2
Merge pull request #11576 from jakevdp:searchsorted-alt
...
PiperOrigin-RevId: 462461853
2022-07-21 13:47:43 -07:00
Sharad Vikram
d6c172d53e
Fix PE not allowing double JIT-ted effectful functions
2022-07-21 11:55:48 -07:00
jax authors
f6c168276b
Merge pull request #11578 from jakevdp:wraps-mod
...
PiperOrigin-RevId: 462437654
2022-07-21 11:50:47 -07:00
Jake VanderPlas
9769a0accf
DOC: ensure that _wraps() generates correct links to wrapped functions
2022-07-21 11:12:35 -07:00
jax authors
a4e754849e
Merge pull request #11543 from nvcastet:fix_multigpu_test
...
PiperOrigin-RevId: 462418103
2022-07-21 10:27:57 -07:00
jax authors
1e05a1cfbc
Merge pull request #10816 from mattjj:remove-old-pjit-comment
...
PiperOrigin-RevId: 462411602
2022-07-21 10:01:57 -07:00
George Necula
6c9d2a0b54
[jax2tf] Raise errors for experimental_native_lowering and custom_call
...
Raise explicit error when the experimental_native_lowering encounters
a mhlo.custom_call. This would lead to failure when trying to run in TF.
2022-07-21 19:58:05 +03:00
Jake VanderPlas
10411bfeae
jnp.searchsorted: add optional method argument to control implementation
2022-07-21 09:40:18 -07:00
George Necula
07fcf79324
jax.mask and jax.shapecheck are being deprecated
...
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
jax authors
ba7ded4331
Merge pull request #11571 from google:skye-patch-2
...
PiperOrigin-RevId: 462267465
2022-07-20 17:36:20 -07:00
Skye Wanderman-Milne
568cedba8d
Update WORKSPACE for 0.3.15 release, take 2
2022-07-20 17:23:52 -07:00
jax authors
be6db2e619
Merge pull request #10775 from pschuh:mlir-caching
...
PiperOrigin-RevId: 462263487
2022-07-20 17:10:40 -07:00
Parker Schuh
6c4da65af4
Add treedef_is_strict_leaf to fix _prefix_error's semantics.
...
Empty nodes like [] and {} have 1 node and 0 leaves. This does not make
them a leaf treedef.
Reproducer:
```
pjit.pjit(lambda x: x, None, (None, {}))((3, {'a': []}))
```
2022-07-20 17:02:59 -07:00
Kuangyuan Chen
c0ec3b33e6
Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
...
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.
PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Yash Katariya
d8cbb29d14
OpSharding doesn't have __eq__
defined on it. Don't check sharding equality using opsharding until it does support that.
...
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Parker Schuh
d8f0099f68
_mlirTransforms merged into _mlirRegisterEverything.
...
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Yash Katariya
ad67d825fe
Add a faster __eq__ check for Mesh. When the id
of self and other is the same, there is no need to compare the devices which can be slow when there are 1000s of devices.
...
PiperOrigin-RevId: 462230016
2022-07-20 14:25:41 -07:00
Yash Katariya
026636951a
Add lru_cache
and use it instead of util.cache()
in places where tracing user code is not required.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 462212010
2022-07-20 13:05:20 -07:00
jax authors
ffe67c1042
Merge pull request #11563 from jakevdp:upstream-ci
...
PiperOrigin-RevId: 462200130
2022-07-20 12:09:08 -07:00