Peter Hawkins
5b576cb03e
Revert: Drop flatbuffers as a Python dependency of JAX.
...
This change appears to be causing crashes on Mac.
Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457559793
jaxlib-v0.3.14
jax-v0.3.14
2022-06-27 13:56:32 -07:00
Yash Katariya
5b865ed557
Add the assume_metadata
option to avoid waiting on ts.open
which is very slow on large models. Instead await on the future in the background thread.
...
PiperOrigin-RevId: 457540435
2022-06-27 12:30:23 -07:00
jax authors
a2f1aee7f3
Merge pull request #11248 from jakevdp:remove-jaxtestcase
...
PiperOrigin-RevId: 457524591
2022-06-27 11:18:57 -07:00
Jake VanderPlas
887abbc3b9
jax.test_util: remove deprecated test classes.
...
JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
2022-06-27 11:04:50 -07:00
jax authors
997beb3ce0
Merge pull request #11273 from hawkinsp:release
...
PiperOrigin-RevId: 457466335
2022-06-27 06:50:28 -07:00
Peter Hawkins
1e29b7b762
Update CHANGELOG.md and setup.py for 0.3.14 release.
2022-06-27 09:38:41 -04:00
jax authors
02603606e7
Merge pull request #11244 from hawkinsp:xla
...
PiperOrigin-RevId: 457461421
2022-06-27 06:20:41 -07:00
Peter Hawkins
f4ddd3ef88
Update XLA.
2022-06-27 09:14:28 -04:00
Peter Hawkins
efefeac450
Drop flatbuffers as a Python dependency of JAX.
...
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
jax authors
93f5113c93
Merge pull request #11250 from LenaMartens:changelist/456788427
...
PiperOrigin-RevId: 457430831
2022-06-27 03:00:23 -07:00
Lena Martens
740fe6926a
Checkify: add (checkify-of-)vmap-of-check.
2022-06-27 10:34:26 +01:00
jax authors
406a61cf52
Merge pull request #11146 from sshahrokhi:AbortIfNotInitialized
...
PiperOrigin-RevId: 457115405
2022-06-24 16:24:57 -07:00
jax authors
62c16da81f
Merge pull request #11255 from ikmckenz:fix-broken-links-design-notes
...
PiperOrigin-RevId: 457101918
2022-06-24 15:08:10 -07:00
Shiva Shahrokhi
df8c6263de
Change JAX_PLATFORMS to raise an exception when platform initialization fails
2022-06-24 21:54:53 +00:00
jax authors
4ad0234d85
Merge pull request #11251 from hawkinsp:scipy
...
PiperOrigin-RevId: 457070748
2022-06-24 12:41:51 -07:00
Ian McKenzie
0cc2ada432
Fix broken links for moved design_notes folder
2022-06-24 12:18:11 -07:00
Peter Hawkins
a560a29e12
Increase the minimum scipy version to 1.5.
...
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00
Yash Katariya
989a3304bf
Fix the creation of pmap sharding spec when sharded_dim is None.
...
PiperOrigin-RevId: 457045869
2022-06-24 10:46:35 -07:00
Yash Katariya
e32373c3ea
Make jnp.array
return jax.Array
. Add input and result handlers for jax.Array
. Also added tests for add
under jit.
...
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
jax authors
53286a9312
Merge pull request #11247 from hawkinsp:maxwell
...
PiperOrigin-RevId: 457024985
2022-06-24 09:13:57 -07:00
Peter Hawkins
fc659d5308
Reduce size of double-sided maxwell random test.
...
It appears that for some inputs this triggers an integer overflow in scipy.stats.maxwell().cdf.
2022-06-24 12:01:20 -04:00
Marc van Zee
4c25ef1d00
Simplifies inverting permutation.
...
PiperOrigin-RevId: 457013218
2022-06-24 08:06:21 -07:00
jax authors
a90bde2c54
Merge pull request #11231 from hawkinsp:remotetpu
...
PiperOrigin-RevId: 457005076
2022-06-24 07:13:16 -07:00
jax authors
932f77e3d5
Merge pull request #11226 from gnecula:tf_bug_fix
...
PiperOrigin-RevId: 456932679
2022-06-23 22:21:06 -07:00
Matthew Johnson
8c5632123b
fix ad_util.Zero handling in broadcast_in_dim_jvp_rule
...
PiperOrigin-RevId: 456922766
2022-06-23 20:54:21 -07:00
jax authors
c9c258ea9b
Merge pull request #11215 from jakevdp:roots-jit
...
PiperOrigin-RevId: 456880017
2022-06-23 15:57:54 -07:00
Matthew Johnson
5f97dc8954
Roll forward with simple fix: handle Zero cotangents in _broadcast_in_dim
...
transpose rule (previously handled by the deflinear2 wrapper, which it's no
longer using).
PiperOrigin-RevId: 456874635
2022-06-23 15:30:22 -07:00
Jake VanderPlas
f6476f7a03
jnp.roots: better support for computation under JIT
2022-06-23 14:48:53 -07:00
Peter Hawkins
22304eeb2e
Add a build flag that allows disabling remote TPU builds.
...
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
jax authors
2744404809
Merge pull request #11230 from jakevdp:fix-numpy-123
...
PiperOrigin-RevId: 456857412
2022-06-23 14:09:35 -07:00
Jake VanderPlas
617df70135
Unpin numpy to ensure most recent version is tested
2022-06-23 12:23:14 -07:00
Jake VanderPlas
eec1225d74
TST: skip tests on numpy 1.23.0 due to regressions in that release
2022-06-23 11:46:51 -07:00
Jake VanderPlas
e92e23e5f8
Use equality rather than identity when checking for float0
...
Why? This is required due to changes to dtype canonicalization in numpy v1.23; see #11221
2022-06-23 11:46:20 -07:00
jax authors
e4d1e1beb3
Copybara import of the project:
...
--
a001c52f878824cd1c0a67c73d9d318ed30286c9 by Matthew Johnson <mattjj@google.com>:
[dynamic-shapes] basic jvp working, including with broadcast
PiperOrigin-RevId: 456822732
2022-06-23 11:32:30 -07:00
jax authors
3737d160b5
Merge pull request #11229 from LenaMartens:changelist/456788425
...
PiperOrigin-RevId: 456803263
2022-06-23 10:25:30 -07:00
jax authors
a9275d1a25
Merge pull request #11156 from mattjj:djax-ad-jvp
...
PiperOrigin-RevId: 456797426
2022-06-23 10:02:32 -07:00
Lena Martens
8efeb3e297
Fix getting aval of BatchTracers that are not mapped.
2022-06-23 17:28:45 +01:00
George Necula
391aaf4177
[jax2tf] Fix the documentation for handling dimension polynomials.
2022-06-23 16:51:22 +03:00
jax authors
77a4528bcf
Merge pull request #11173 from gnecula:large_prng
...
PiperOrigin-RevId: 456752688
2022-06-23 06:27:42 -07:00
Kuangyuan Chen
dc1c519547
Reduce jax.jit dispatch overhead by avoiding directly comparing python objects
...
Previously the thread local state might be updated, leading to expensive python compare logic during compilation cache lookup. This CL adds a thread local cache for the state.
PiperOrigin-RevId: 456667829
2022-06-22 20:04:40 -07:00
Yash Katariya
1908da33af
Only initialize GPU backends if they are not already initialized
...
PiperOrigin-RevId: 456664792
2022-06-22 19:39:52 -07:00
Yash Katariya
b623ed58b0
Add a Multiprocess gpu test to test the distributed.initialize() function.
...
PiperOrigin-RevId: 456633768
2022-06-22 16:20:47 -07:00
Qiao Zhang
be71989af6
Remove broken image link.
...
PiperOrigin-RevId: 456628284
2022-06-22 15:54:36 -07:00
jax authors
3711e5f71a
Merge pull request #10840 from jakevdp:strict-promotion-default
...
PiperOrigin-RevId: 456586131
2022-06-22 12:54:20 -07:00
jax authors
86d8a467ba
Merge pull request #11186 from jakevdp:x64-promotion-error
...
PiperOrigin-RevId: 456584041
2022-06-22 12:45:27 -07:00
Yash Katariya
766c5ba0a2
Check sharding in pmap for jax.Array
.
...
The checks are:
(1) Check if the in_axes given to pmap matches the sharding of Array.
(2) Check if devices in `array.sharding` is equal to the devices provided to pmap
(3) Check if devices for all array inputs are the same.
(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).
PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Jake VanderPlas
6439435478
Set jax_numpy_dtype_promotion='strict' in tests
2022-06-22 11:22:09 -07:00
jax authors
6a22f586f9
Merge pull request #11207 from jakevdp:x64-scipy-optimize-test
...
PiperOrigin-RevId: 456560836
2022-06-22 11:21:03 -07:00
Jake VanderPlas
85660f5363
[x64] make scipy_optimize_test compatible with strict dtype promotion
2022-06-22 11:04:20 -07:00
Ruoxin Sang
0a14a81704
Fix mismatched parentheses in jax2tf code examples.
...
PiperOrigin-RevId: 456531544
2022-06-22 09:33:18 -07:00