11145 Commits

Author SHA1 Message Date
Peter Hawkins
7ffdac0746 Skip remote_transfer_test on cloud TPU.
The necessary API isn't yet implemented for Cloud TPU.

PiperOrigin-RevId: 442058546
2022-04-15 11:24:52 -07:00
jax authors
a4b8a443be Merge pull request #10288 from YouJiacheng:patch-7
PiperOrigin-RevId: 442043193
2022-04-15 10:19:08 -07:00
YouJiacheng
4ff6b1fbca Fix PRNGKeyArray.broadcast_to with scalar shape 2022-04-16 00:27:30 +08:00
jax authors
470f58c9cd Merge pull request #10309 from hawkinsp:jaxlib
PiperOrigin-RevId: 442030934
2022-04-15 09:19:36 -07:00
Peter Hawkins
52a97f2e06 Jax 0.3.7 and jaxlib 0.3.7 release. 2022-04-15 12:02:05 -04:00
Peter Hawkins
0c1021ad4b Temporarily disable integer index check in jnp.take_along_axis.
This check broke some JAX users; disable it to give time to fix them.

PiperOrigin-RevId: 441993808
2022-04-15 05:45:18 -07:00
jax authors
375777f43c Merge pull request #9569 from GJBoth:tree_flatten_docs
PiperOrigin-RevId: 441878577
2022-04-14 16:09:47 -07:00
jax authors
0443f5ed9a Merge pull request #10216 from lgeiger:slice-none
PiperOrigin-RevId: 441877962
2022-04-14 16:04:48 -07:00
jax authors
b290b6eaa3 Merge pull request #10294 from jakevdp:fix-gpu-test
PiperOrigin-RevId: 441855902
2022-04-14 14:31:07 -07:00
jax authors
f8d7969392 Merge pull request #10286 from hawkinsp:takealong
PiperOrigin-RevId: 441853933
2022-04-14 14:26:24 -07:00
jax authors
21ec079c64 Merge pull request #10290 from jakevdp:utils-comment
PiperOrigin-RevId: 441853771
2022-04-14 14:21:30 -07:00
Jake VanderPlas
cadc8046d5 [sparse] gate gpu lowering test to gpu backends 2022-04-14 14:11:11 -07:00
Peter Hawkins
c2fe97ae01 Improve precision of chlo.sinh.
Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh

[JAX] Use chlo.sinh instead of the XLA client library HLO lowering.

PiperOrigin-RevId: 441851170
2022-04-14 14:10:26 -07:00
Jake VanderPlas
72470dee3a Comment on implementation of unzip2 & unzip3 2022-04-14 13:41:05 -07:00
jax authors
78fb120f86 Merge pull request #10289 from hawkinsp:docs5
PiperOrigin-RevId: 441841485
2022-04-14 13:33:07 -07:00
Peter Hawkins
c8ac813ec1 Avoid broadcasting the input and indices in jnp.take_along_axis.
In #1521 we added broadcasting to fix an apparent wrong-gradient bug. This
worked, but the real issue was that we were mishandling the case where the
array dimension is of size 1 but the index dimension is not. In that case we
in essence gathered a bunch of out of bounds indices, leading to apparently
incorrect gradients.

The previous fix (broadcasting) worked, but was suboptimal in terms of
performance (#10281). However, we can fix both bugs by removing the broadcasting
and handling the missing case correctly.

Fixes #10281.
2022-04-14 16:21:56 -04:00
Peter Hawkins
3f1032cf33 Fix incorrect cross-reference breaking readthedocs build. 2022-04-14 16:12:48 -04:00
Yilei Yang
7ad1120da0 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 441831488
2022-04-14 12:52:51 -07:00
George Necula
d050327592 Deprecate jax.experimental.loops, step 2.
Add deprecation warning and remove the tests.

PiperOrigin-RevId: 441828243
2022-04-14 12:38:55 -07:00
jax authors
9cf2476c40 Merge pull request #10188 from sharadmv:jax2tf-name-stack
PiperOrigin-RevId: 441819649
2022-04-14 12:00:56 -07:00
jax authors
6914e35af1 Merge pull request #10270 from mattjj:djax-iree
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Peter Hawkins
6c1461b52b [MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
PiperOrigin-RevId: 441769591
2022-04-14 08:38:21 -07:00
Peter Hawkins
4806c29bf7 [MHLO] Add MHLO lowerings for FFT ops.
PiperOrigin-RevId: 441768017
2022-04-14 08:31:17 -07:00
Peter Hawkins
b3a62cd3f2 Disable remote_transfer_test on GPU. It currently crashes.
PiperOrigin-RevId: 441762941
2022-04-14 08:06:27 -07:00
Jonathan Heek
df20bd2de5 Expose CopyToRemoteDevice and MakeCrossHostReceiveBuffer in Python bindings.
PiperOrigin-RevId: 441746248
2022-04-14 06:40:48 -07:00
Peter Hawkins
8980bc4bd9 [MHLO] Add MHLO lowerings for name_p and unreachable_p.
PiperOrigin-RevId: 441746096
2022-04-14 06:35:46 -07:00
Lena Martens
e187428a54 Restructure checkify files.
PiperOrigin-RevId: 441726310
2022-04-14 04:32:24 -07:00
Peter Hawkins
665df8dfaa [MHLO] Add an MHLO lowering for rng_bit_generator.
PiperOrigin-RevId: 441628987
2022-04-13 18:09:36 -07:00
jax authors
58f8e93d72 Merge pull request #10268 from jakevdp:polydiv-kokoro
PiperOrigin-RevId: 441569627
2022-04-13 13:49:08 -07:00
Jake VanderPlas
34e206a89e Fix polydiv kokoro tests 2022-04-13 13:21:29 -07:00
jax authors
553924f774 Merge pull request #10265 from hawkinsp:conda
PiperOrigin-RevId: 441554415
2022-04-13 12:50:21 -07:00
jax authors
191c83816c Merge pull request #10226 from ljjsalt:add-polydiv
PiperOrigin-RevId: 441548874
2022-04-13 12:27:22 -07:00
Peter Hawkins
21f95d531b Remove use of xla.lower_fun in SVD translation rule.
This is the only use of xla.lower_fun that is still needed (as a fallback) when the non-MHLO path is removed.

PiperOrigin-RevId: 441538472
2022-04-13 11:44:45 -07:00
Sharad Vikram
1b60e353a2 Enable context manager name stack with jax2tf 2022-04-13 11:38:21 -07:00
Peter Hawkins
2a68e7c975 Add libstdc++ workaround for conda users. 2022-04-13 14:38:09 -04:00
jax authors
eb4307178c Merge pull request #10264 from hawkinsp:six
PiperOrigin-RevId: 441535896
2022-04-13 11:35:25 -07:00
Jiajie Li
128e51c638 Add polydiv to jax.numpy
Fix code style, fix tests

Add warning when use polydiv with trim_leading_zeros

Update warning for polydiv

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Enable type check in _CompileAndCheck

Fix cutoff

Fix cut-off in polydiv

Add trim_zeros_tol, remove redundant code in polydiv

Remove unused import

Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
jax authors
0b898ea627 Merge pull request #10251 from mattjj:debug-nans-error
PiperOrigin-RevId: 441530769
2022-04-13 11:17:23 -07:00
Peter Hawkins
af733ee0c0 Remove mention of six package from jaxlib build instructions.
I verified that six is no longer needed.
2022-04-13 14:09:01 -04:00
jax authors
e5f19138d6 Merge pull request #10262 from jakevdp:while-loop-error
PiperOrigin-RevId: 441527861
2022-04-13 11:08:25 -07:00
jax authors
23f1ef6ad3 Merge pull request #10263 from hawkinsp:minver
PiperOrigin-RevId: 441526817
2022-04-13 11:03:11 -07:00
jax authors
9a3d7c90e3 Merge pull request #10254 from mattjj:default-args-for-config-context-managers
PiperOrigin-RevId: 441521198
2022-04-13 10:43:12 -07:00
Jake VanderPlas
1a8c57d272 better errors: check for callability of lax.control_flow arguments 2022-04-13 10:39:01 -07:00
jax authors
e8ae9d4dbb Merge pull request #10220 from YouJiacheng:Fix#10219
PiperOrigin-RevId: 441515789
2022-04-13 10:34:32 -07:00
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
jax authors
f72aff56e4 Merge pull request #10248 from YouJiacheng:Fix#10247
PiperOrigin-RevId: 441515740
2022-04-13 10:24:06 -07:00
jax authors
a7ec509465 Merge pull request #10260 from jakevdp:schur-gpu
PiperOrigin-RevId: 441515694
2022-04-13 10:23:51 -07:00
Yash Katariya
eda5bbb514 Expose the input and output sharding on the compiled object.
PiperOrigin-RevId: 441514572
2022-04-13 10:18:25 -07:00
Jake VanderPlas
7bfc86e17f Fix arguments to schur translation rule 2022-04-13 09:50:33 -07:00