11127 Commits

Author SHA1 Message Date
Peter Hawkins
3f1032cf33 Fix incorrect cross-reference breaking readthedocs build. 2022-04-14 16:12:48 -04:00
Yilei Yang
7ad1120da0 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 441831488
2022-04-14 12:52:51 -07:00
George Necula
d050327592 Deprecate jax.experimental.loops, step 2.
Add deprecation warning and remove the tests.

PiperOrigin-RevId: 441828243
2022-04-14 12:38:55 -07:00
jax authors
9cf2476c40 Merge pull request #10188 from sharadmv:jax2tf-name-stack
PiperOrigin-RevId: 441819649
2022-04-14 12:00:56 -07:00
jax authors
6914e35af1 Merge pull request #10270 from mattjj:djax-iree
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Peter Hawkins
6c1461b52b [MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
PiperOrigin-RevId: 441769591
2022-04-14 08:38:21 -07:00
Peter Hawkins
4806c29bf7 [MHLO] Add MHLO lowerings for FFT ops.
PiperOrigin-RevId: 441768017
2022-04-14 08:31:17 -07:00
Peter Hawkins
b3a62cd3f2 Disable remote_transfer_test on GPU. It currently crashes.
PiperOrigin-RevId: 441762941
2022-04-14 08:06:27 -07:00
Jonathan Heek
df20bd2de5 Expose CopyToRemoteDevice and MakeCrossHostReceiveBuffer in Python bindings.
PiperOrigin-RevId: 441746248
2022-04-14 06:40:48 -07:00
Peter Hawkins
8980bc4bd9 [MHLO] Add MHLO lowerings for name_p and unreachable_p.
PiperOrigin-RevId: 441746096
2022-04-14 06:35:46 -07:00
Lena Martens
e187428a54 Restructure checkify files.
PiperOrigin-RevId: 441726310
2022-04-14 04:32:24 -07:00
Peter Hawkins
665df8dfaa [MHLO] Add an MHLO lowering for rng_bit_generator.
PiperOrigin-RevId: 441628987
2022-04-13 18:09:36 -07:00
jax authors
58f8e93d72 Merge pull request #10268 from jakevdp:polydiv-kokoro
PiperOrigin-RevId: 441569627
2022-04-13 13:49:08 -07:00
Jake VanderPlas
34e206a89e Fix polydiv kokoro tests 2022-04-13 13:21:29 -07:00
jax authors
553924f774 Merge pull request #10265 from hawkinsp:conda
PiperOrigin-RevId: 441554415
2022-04-13 12:50:21 -07:00
jax authors
191c83816c Merge pull request #10226 from ljjsalt:add-polydiv
PiperOrigin-RevId: 441548874
2022-04-13 12:27:22 -07:00
Peter Hawkins
21f95d531b Remove use of xla.lower_fun in SVD translation rule.
This is the only use of xla.lower_fun that is still needed (as a fallback) when the non-MHLO path is removed.

PiperOrigin-RevId: 441538472
2022-04-13 11:44:45 -07:00
Sharad Vikram
1b60e353a2 Enable context manager name stack with jax2tf 2022-04-13 11:38:21 -07:00
Peter Hawkins
2a68e7c975 Add libstdc++ workaround for conda users. 2022-04-13 14:38:09 -04:00
jax authors
eb4307178c Merge pull request #10264 from hawkinsp:six
PiperOrigin-RevId: 441535896
2022-04-13 11:35:25 -07:00
Jiajie Li
128e51c638 Add polydiv to jax.numpy
Fix code style, fix tests

Add warning when use polydiv with trim_leading_zeros

Update warning for polydiv

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Enable type check in _CompileAndCheck

Fix cutoff

Fix cut-off in polydiv

Add trim_zeros_tol, remove redundant code in polydiv

Remove unused import

Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
jax authors
0b898ea627 Merge pull request #10251 from mattjj:debug-nans-error
PiperOrigin-RevId: 441530769
2022-04-13 11:17:23 -07:00
Peter Hawkins
af733ee0c0 Remove mention of six package from jaxlib build instructions.
I verified that six is no longer needed.
2022-04-13 14:09:01 -04:00
jax authors
e5f19138d6 Merge pull request #10262 from jakevdp:while-loop-error
PiperOrigin-RevId: 441527861
2022-04-13 11:08:25 -07:00
jax authors
23f1ef6ad3 Merge pull request #10263 from hawkinsp:minver
PiperOrigin-RevId: 441526817
2022-04-13 11:03:11 -07:00
jax authors
9a3d7c90e3 Merge pull request #10254 from mattjj:default-args-for-config-context-managers
PiperOrigin-RevId: 441521198
2022-04-13 10:43:12 -07:00
Jake VanderPlas
1a8c57d272 better errors: check for callability of lax.control_flow arguments 2022-04-13 10:39:01 -07:00
jax authors
e8ae9d4dbb Merge pull request #10220 from YouJiacheng:Fix#10219
PiperOrigin-RevId: 441515789
2022-04-13 10:34:32 -07:00
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
jax authors
f72aff56e4 Merge pull request #10248 from YouJiacheng:Fix#10247
PiperOrigin-RevId: 441515740
2022-04-13 10:24:06 -07:00
jax authors
a7ec509465 Merge pull request #10260 from jakevdp:schur-gpu
PiperOrigin-RevId: 441515694
2022-04-13 10:23:51 -07:00
Yash Katariya
eda5bbb514 Expose the input and output sharding on the compiled object.
PiperOrigin-RevId: 441514572
2022-04-13 10:18:25 -07:00
Jake VanderPlas
7bfc86e17f Fix arguments to schur translation rule 2022-04-13 09:50:33 -07:00
jax authors
86c8446c00 Merge pull request #10229 from hyeontaek:transfer-guard-remove-compat-code
PiperOrigin-RevId: 441490830
2022-04-13 08:45:28 -07:00
Peter Hawkins
cb4abe754a [MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.

PiperOrigin-RevId: 441474701
2022-04-13 07:26:26 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Marc van Zee
af89426a73 [jax2tf] Fixes conv bug. The filter dims are allowed to be bigger than the input dims if the padding type is "SAME".
PiperOrigin-RevId: 441395527
2022-04-13 00:07:08 -07:00
jax authors
eb83a4711f Merge pull request #10249 from sharadmv:jaxpr-effects
PiperOrigin-RevId: 441355920
2022-04-12 19:29:01 -07:00
Yash Katariya
08f28a6119 Finish jax release
PiperOrigin-RevId: 441342919
2022-04-12 18:13:16 -07:00
Sharad Vikram
4392b07022 Add tests for higher order primitives 2022-04-12 18:12:44 -07:00
Yash Katariya
6ba9fb699d Upgrade the bazel version to 5.1.1
PiperOrigin-RevId: 441338363
jax-v0.3.6
2022-04-12 17:48:09 -07:00
Matthew Johnson
2a46c5e0d8 add default values to config context managers 2022-04-12 15:05:53 -07:00
Matthew Johnson
8bc8e40e72 debug_nans: don't return results of successfully running de-optimized function 2022-04-12 14:40:19 -07:00
YouJiacheng
dad324d934 Fix#10247 2022-04-13 04:05:55 +08:00
YouJiacheng
4695dd919c Fix#10219 2022-04-13 04:04:11 +08:00
jax authors
c06eff8cd8 Merge pull request #10245 from google:yashk2810-patch-7
PiperOrigin-RevId: 441265709
2022-04-12 12:54:22 -07:00
Yash Katariya
5fd78eaf02
Bump the libtpu version to prepare for JAX release 2022-04-12 11:41:07 -07:00
Peter Hawkins
9455254b9f [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted.

PiperOrigin-RevId: 441214076
2022-04-12 10:30:52 -07:00
Yash Katariya
3136004c62 Fix the pytype error. PyType is looking for a __init__ method. This does not change the behavior of the class.
```
Function PartitionSpec.__init__ expects 1 arg(s), got 3 [wrong-arg-count]
         Expected: (self)
  Actually passed: (self, _, _)
```

PiperOrigin-RevId: 441211351
2022-04-12 09:36:28 -07:00