4044 Commits

Author SHA1 Message Date
Peter Hawkins
b3a62cd3f2 Disable remote_transfer_test on GPU. It currently crashes.
PiperOrigin-RevId: 441762941
2022-04-14 08:06:27 -07:00
Jonathan Heek
df20bd2de5 Expose CopyToRemoteDevice and MakeCrossHostReceiveBuffer in Python bindings.
PiperOrigin-RevId: 441746248
2022-04-14 06:40:48 -07:00
Jake VanderPlas
34e206a89e Fix polydiv kokoro tests 2022-04-13 13:21:29 -07:00
jax authors
191c83816c Merge pull request #10226 from ljjsalt:add-polydiv
PiperOrigin-RevId: 441548874
2022-04-13 12:27:22 -07:00
Jiajie Li
128e51c638 Add polydiv to jax.numpy
Fix code style, fix tests

Add warning when use polydiv with trim_leading_zeros

Update warning for polydiv

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Enable type check in _CompileAndCheck

Fix cutoff

Fix cut-off in polydiv

Add trim_zeros_tol, remove redundant code in polydiv

Remove unused import

Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
jax authors
0b898ea627 Merge pull request #10251 from mattjj:debug-nans-error
PiperOrigin-RevId: 441530769
2022-04-13 11:17:23 -07:00
jax authors
e5f19138d6 Merge pull request #10262 from jakevdp:while-loop-error
PiperOrigin-RevId: 441527861
2022-04-13 11:08:25 -07:00
jax authors
23f1ef6ad3 Merge pull request #10263 from hawkinsp:minver
PiperOrigin-RevId: 441526817
2022-04-13 11:03:11 -07:00
Jake VanderPlas
1a8c57d272 better errors: check for callability of lax.control_flow arguments 2022-04-13 10:39:01 -07:00
jax authors
e8ae9d4dbb Merge pull request #10220 from YouJiacheng:Fix#10219
PiperOrigin-RevId: 441515789
2022-04-13 10:34:32 -07:00
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
Yash Katariya
eda5bbb514 Expose the input and output sharding on the compiled object.
PiperOrigin-RevId: 441514572
2022-04-13 10:18:25 -07:00
jax authors
86c8446c00 Merge pull request #10229 from hyeontaek:transfer-guard-remove-compat-code
PiperOrigin-RevId: 441490830
2022-04-13 08:45:28 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Sharad Vikram
4392b07022 Add tests for higher order primitives 2022-04-12 18:12:44 -07:00
Matthew Johnson
8bc8e40e72 debug_nans: don't return results of successfully running de-optimized function 2022-04-12 14:40:19 -07:00
YouJiacheng
4695dd919c Fix#10219 2022-04-13 04:04:11 +08:00
Peter Hawkins
9455254b9f [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted.

PiperOrigin-RevId: 441214076
2022-04-12 10:30:52 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Hyeontaek Lim
36df8619d7 Bump minimum jaxlib version to 0.3.2 and remove transfer guard compatibility code 2022-04-11 15:33:27 +00:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Tianjian Lu
a11b41f581 [sparse] Use sorted indices instead of sorted rows only.
PiperOrigin-RevId: 440579642
2022-04-09 08:33:48 -07:00
Peter Hawkins
94307a02c8 Revert: [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
PiperOrigin-RevId: 440452521
2022-04-08 14:22:15 -07:00
Matthew Johnson
272ed95858 remove experimental/djax
PiperOrigin-RevId: 440445082
2022-04-08 13:55:22 -07:00
jax authors
0bfb3efcd7 [JAX] Fix batch logic for approx_min/max_k
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.

PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
Peter Hawkins
0f15fa3b10 [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
PiperOrigin-RevId: 440433044
2022-04-08 12:57:59 -07:00
jax authors
b8602d018b Merge pull request #10198 from jakevdp:bcoo-duplicates
PiperOrigin-RevId: 440423326
2022-04-08 12:12:12 -07:00
Yash Katariya
654e5bd922 Roll forward again after the fix in the auto sharding pass.
PiperOrigin-RevId: 440412218
2022-04-08 11:25:07 -07:00
Jake VanderPlas
8b9efe79e7 [sparse] fix autodiff bug in spdot_general 2022-04-08 11:04:26 -07:00
Peter Hawkins
648a512488 [MHLO] Add direct MHLO lowerings for sparse primitives.
PiperOrigin-RevId: 440374054
2022-04-08 08:43:57 -07:00
Joan Puigcerver
0c02f7935a Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
PiperOrigin-RevId: 440300882
2022-04-08 01:08:55 -07:00
Jake VanderPlas
01e4fa8a78 [sparse] consolidate flavors of bcoo_dot_general 2022-04-07 11:28:12 -07:00
jax authors
8b3f039252 Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf
PiperOrigin-RevId: 439997145
2022-04-06 19:55:29 -07:00
jax authors
7ee6adb1a5 Merge pull request #10173 from jakevdp:bcoo-add-batchdim
PiperOrigin-RevId: 439955276
2022-04-06 15:50:10 -07:00
Yash Katariya
6a7a34603d Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
This move is because sharded_jit is being deprecated.

PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Jake VanderPlas
93a24f3b83 [sparse] add bcoo_add_batchdim 2022-04-06 14:44:29 -07:00
Alex Riley
869596fc2c Add jax.scipy.linalg.rsf2csf 2022-04-06 21:06:23 +01:00
Jake VanderPlas
5a96c0cb18 Skip test outside x64 2022-04-04 16:00:18 -07:00
Peter Hawkins
71a5eb263b [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
Fixes https://github.com/google/jax/issues/9946

PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Yash Katariya
6825f654b1 * Disallow any other type other than GDA and ShapedArray for auto sharding.
* Raise errors in the following 4 cases when GDAs sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.**
  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
  * NO auto sharding
    * f_pjitted(gda) -- This is already covered and tested and happens in `infer_params`
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch

PiperOrigin-RevId: 439413895
2022-04-04 14:33:51 -07:00
Peter Hawkins
1b8be90801 Remove the jax_enable_mlir flag. MLIR is now the only supported code path.
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.

PiperOrigin-RevId: 439324450
2022-04-04 08:40:09 -07:00
jax authors
e64a57d2c3 Merge pull request #10121 from hawkinsp:hcbcache
PiperOrigin-RevId: 439036780
2022-04-02 08:57:24 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
1c3edc811d Merge pull request #10110 from pschuh:weakref-bug
PiperOrigin-RevId: 438887762
2022-04-01 12:45:35 -07:00
Parker Schuh
df1c478ec5 Fix race condition for weakref destructor by catching rare exceptions. 2022-04-01 12:04:36 -07:00
Peter Hawkins
208e83ceb7 Avoid retracing when a host_callback.call is called multiple times with the same function.
If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality.

This issue was found in passing while debugging #9970.
2022-04-01 14:41:14 -04:00
jax authors
e766b96063 Merge pull request #10058 from yotarok:istft
PiperOrigin-RevId: 438832534
2022-04-01 08:43:27 -07:00
jax authors
4decbcb00e Merge pull request #10103 from LenaMartens:changelist/438319917
PiperOrigin-RevId: 438821559
2022-04-01 07:40:45 -07:00
Yotaro Kubo
a7fd751acf Add istft to jax.scipy.signal. 2022-04-01 14:28:53 +09:00