Peter Hawkins
b3a62cd3f2
Disable remote_transfer_test on GPU. It currently crashes.
...
PiperOrigin-RevId: 441762941
2022-04-14 08:06:27 -07:00
Jonathan Heek
df20bd2de5
Expose CopyToRemoteDevice and MakeCrossHostReceiveBuffer in Python bindings.
...
PiperOrigin-RevId: 441746248
2022-04-14 06:40:48 -07:00
Jake VanderPlas
34e206a89e
Fix polydiv kokoro tests
2022-04-13 13:21:29 -07:00
jax authors
191c83816c
Merge pull request #10226 from ljjsalt:add-polydiv
...
PiperOrigin-RevId: 441548874
2022-04-13 12:27:22 -07:00
Jiajie Li
128e51c638
Add polydiv to jax.numpy
...
Fix code style, fix tests
Add warning when use polydiv with trim_leading_zeros
Update warning for polydiv
Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>
Enable type check in _CompileAndCheck
Fix cutoff
Fix cut-off in polydiv
Add trim_zeros_tol, remove redundant code in polydiv
Remove unused import
Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
jax authors
0b898ea627
Merge pull request #10251 from mattjj:debug-nans-error
...
PiperOrigin-RevId: 441530769
2022-04-13 11:17:23 -07:00
jax authors
e5f19138d6
Merge pull request #10262 from jakevdp:while-loop-error
...
PiperOrigin-RevId: 441527861
2022-04-13 11:08:25 -07:00
jax authors
23f1ef6ad3
Merge pull request #10263 from hawkinsp:minver
...
PiperOrigin-RevId: 441526817
2022-04-13 11:03:11 -07:00
Jake VanderPlas
1a8c57d272
better errors: check for callability of lax.control_flow arguments
2022-04-13 10:39:01 -07:00
jax authors
e8ae9d4dbb
Merge pull request #10220 from YouJiacheng:Fix#10219
...
PiperOrigin-RevId: 441515789
2022-04-13 10:34:32 -07:00
Peter Hawkins
94efc90939
Drop dead code now that the minimum jaxlib version is 0.3.2.
2022-04-13 13:34:00 -04:00
Yash Katariya
eda5bbb514
Expose the input and output sharding on the compiled object.
...
PiperOrigin-RevId: 441514572
2022-04-13 10:18:25 -07:00
jax authors
86c8446c00
Merge pull request #10229 from hyeontaek:transfer-guard-remove-compat-code
...
PiperOrigin-RevId: 441490830
2022-04-13 08:45:28 -07:00
Peter Hawkins
ad8e6ada4e
[MHLO] Change jax.xla_computation() to use MHLO lowering internally.
...
Change in preparation for removing the non-MHLO lowering path.
PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Sharad Vikram
4392b07022
Add tests for higher order primitives
2022-04-12 18:12:44 -07:00
Matthew Johnson
8bc8e40e72
debug_nans: don't return results of successfully running de-optimized function
2022-04-12 14:40:19 -07:00
YouJiacheng
4695dd919c
Fix#10219
2022-04-13 04:04:11 +08:00
Peter Hawkins
9455254b9f
[MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
...
This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted.
PiperOrigin-RevId: 441214076
2022-04-12 10:30:52 -07:00
Matthew Johnson
4354f355a8
prototyping dynamic shapes
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Sharad Vikram
0fa1eddd25
Adds simple effect types to jaxprs
2022-04-11 11:50:41 -07:00
Hyeontaek Lim
36df8619d7
Bump minimum jaxlib version to 0.3.2 and remove transfer guard compatibility code
2022-04-11 15:33:27 +00:00
Matthew Johnson
902fc0c3d2
Remove invertible_ad since it's not in use.
...
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Tianjian Lu
a11b41f581
[sparse] Use sorted indices instead of sorted rows only.
...
PiperOrigin-RevId: 440579642
2022-04-09 08:33:48 -07:00
Peter Hawkins
94307a02c8
Revert: [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
...
PiperOrigin-RevId: 440452521
2022-04-08 14:22:15 -07:00
Matthew Johnson
272ed95858
remove experimental/djax
...
PiperOrigin-RevId: 440445082
2022-04-08 13:55:22 -07:00
jax authors
0bfb3efcd7
[JAX] Fix batch logic for approx_min/max_k
...
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.
PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
Peter Hawkins
0f15fa3b10
[MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
...
PiperOrigin-RevId: 440433044
2022-04-08 12:57:59 -07:00
jax authors
b8602d018b
Merge pull request #10198 from jakevdp:bcoo-duplicates
...
PiperOrigin-RevId: 440423326
2022-04-08 12:12:12 -07:00
Yash Katariya
654e5bd922
Roll forward again after the fix in the auto sharding pass.
...
PiperOrigin-RevId: 440412218
2022-04-08 11:25:07 -07:00
Jake VanderPlas
8b9efe79e7
[sparse] fix autodiff bug in spdot_general
2022-04-08 11:04:26 -07:00
Peter Hawkins
648a512488
[MHLO] Add direct MHLO lowerings for sparse primitives.
...
PiperOrigin-RevId: 440374054
2022-04-08 08:43:57 -07:00
Joan Puigcerver
0c02f7935a
Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
...
PiperOrigin-RevId: 440300882
2022-04-08 01:08:55 -07:00
Jake VanderPlas
01e4fa8a78
[sparse] consolidate flavors of bcoo_dot_general
2022-04-07 11:28:12 -07:00
jax authors
8b3f039252
Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf
...
PiperOrigin-RevId: 439997145
2022-04-06 19:55:29 -07:00
jax authors
7ee6adb1a5
Merge pull request #10173 from jakevdp:bcoo-add-batchdim
...
PiperOrigin-RevId: 439955276
2022-04-06 15:50:10 -07:00
Yash Katariya
6a7a34603d
Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
...
This move is because sharded_jit is being deprecated.
PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Jake VanderPlas
93a24f3b83
[sparse] add bcoo_add_batchdim
2022-04-06 14:44:29 -07:00
Alex Riley
869596fc2c
Add jax.scipy.linalg.rsf2csf
2022-04-06 21:06:23 +01:00
Jake VanderPlas
5a96c0cb18
Skip test outside x64
2022-04-04 16:00:18 -07:00
Peter Hawkins
71a5eb263b
[GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
...
Fixes https://github.com/google/jax/issues/9946
PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Yash Katariya
6825f654b1
* Disallow any other type other than GDA and ShapedArray for auto sharding.
...
* Raise errors in the following 4 cases when GDAs sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.**
* Auto sharding
* f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True.
* compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
* NO auto sharding
* f_pjitted(gda) -- This is already covered and tested and happens in `infer_params`
* compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
PiperOrigin-RevId: 439413895
2022-04-04 14:33:51 -07:00
Peter Hawkins
1b8be90801
Remove the jax_enable_mlir flag. MLIR is now the only supported code path.
...
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.
PiperOrigin-RevId: 439324450
2022-04-04 08:40:09 -07:00
jax authors
e64a57d2c3
Merge pull request #10121 from hawkinsp:hcbcache
...
PiperOrigin-RevId: 439036780
2022-04-02 08:57:24 -07:00
Jake VanderPlas
df1ceaeeb1
Deprecate jax.tree_util.tree_multimap
2022-04-01 14:51:54 -07:00
jax authors
1c3edc811d
Merge pull request #10110 from pschuh:weakref-bug
...
PiperOrigin-RevId: 438887762
2022-04-01 12:45:35 -07:00
Parker Schuh
df1c478ec5
Fix race condition for weakref destructor by catching rare exceptions.
2022-04-01 12:04:36 -07:00
Peter Hawkins
208e83ceb7
Avoid retracing when a host_callback.call is called multiple times with the same function.
...
If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality.
This issue was found in passing while debugging #9970 .
2022-04-01 14:41:14 -04:00
jax authors
e766b96063
Merge pull request #10058 from yotarok:istft
...
PiperOrigin-RevId: 438832534
2022-04-01 08:43:27 -07:00
jax authors
4decbcb00e
Merge pull request #10103 from LenaMartens:changelist/438319917
...
PiperOrigin-RevId: 438821559
2022-04-01 07:40:45 -07:00
Yotaro Kubo
a7fd751acf
Add istft to jax.scipy.signal.
2022-04-01 14:28:53 +09:00