12059 Commits

Author SHA1 Message Date
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
jaxlib-v0.3.14 jax-v0.3.14
2022-06-27 13:56:32 -07:00
Yash Katariya
5b865ed557 Add the assume_metadata option to avoid waiting on ts.open which is very slow on large models. Instead await on the future in the background thread.
PiperOrigin-RevId: 457540435
2022-06-27 12:30:23 -07:00
jax authors
a2f1aee7f3 Merge pull request #11248 from jakevdp:remove-jaxtestcase
PiperOrigin-RevId: 457524591
2022-06-27 11:18:57 -07:00
Jake VanderPlas
887abbc3b9 jax.test_util: remove deprecated test classes.
JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
2022-06-27 11:04:50 -07:00
jax authors
997beb3ce0 Merge pull request #11273 from hawkinsp:release
PiperOrigin-RevId: 457466335
2022-06-27 06:50:28 -07:00
Peter Hawkins
1e29b7b762 Update CHANGELOG.md and setup.py for 0.3.14 release. 2022-06-27 09:38:41 -04:00
jax authors
02603606e7 Merge pull request #11244 from hawkinsp:xla
PiperOrigin-RevId: 457461421
2022-06-27 06:20:41 -07:00
Peter Hawkins
f4ddd3ef88 Update XLA. 2022-06-27 09:14:28 -04:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
jax authors
93f5113c93 Merge pull request #11250 from LenaMartens:changelist/456788427
PiperOrigin-RevId: 457430831
2022-06-27 03:00:23 -07:00
Lena Martens
740fe6926a Checkify: add (checkify-of-)vmap-of-check. 2022-06-27 10:34:26 +01:00
jax authors
406a61cf52 Merge pull request #11146 from sshahrokhi:AbortIfNotInitialized
PiperOrigin-RevId: 457115405
2022-06-24 16:24:57 -07:00
jax authors
62c16da81f Merge pull request #11255 from ikmckenz:fix-broken-links-design-notes
PiperOrigin-RevId: 457101918
2022-06-24 15:08:10 -07:00
Shiva Shahrokhi
df8c6263de Change JAX_PLATFORMS to raise an exception when platform initialization fails 2022-06-24 21:54:53 +00:00
jax authors
4ad0234d85 Merge pull request #11251 from hawkinsp:scipy
PiperOrigin-RevId: 457070748
2022-06-24 12:41:51 -07:00
Ian McKenzie
0cc2ada432 Fix broken links for moved design_notes folder 2022-06-24 12:18:11 -07:00
Peter Hawkins
a560a29e12 Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00
Yash Katariya
989a3304bf Fix the creation of pmap sharding spec when sharded_dim is None.
PiperOrigin-RevId: 457045869
2022-06-24 10:46:35 -07:00
Yash Katariya
e32373c3ea Make jnp.array return jax.Array. Add input and result handlers for jax.Array. Also added tests for add under jit.
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
jax authors
53286a9312 Merge pull request #11247 from hawkinsp:maxwell
PiperOrigin-RevId: 457024985
2022-06-24 09:13:57 -07:00
Peter Hawkins
fc659d5308 Reduce size of double-sided maxwell random test.
It appears that for some inputs this triggers an integer overflow in scipy.stats.maxwell().cdf.
2022-06-24 12:01:20 -04:00
Marc van Zee
4c25ef1d00 Simplifies inverting permutation.
PiperOrigin-RevId: 457013218
2022-06-24 08:06:21 -07:00
jax authors
a90bde2c54 Merge pull request #11231 from hawkinsp:remotetpu
PiperOrigin-RevId: 457005076
2022-06-24 07:13:16 -07:00
jax authors
932f77e3d5 Merge pull request #11226 from gnecula:tf_bug_fix
PiperOrigin-RevId: 456932679
2022-06-23 22:21:06 -07:00
Matthew Johnson
8c5632123b fix ad_util.Zero handling in broadcast_in_dim_jvp_rule
PiperOrigin-RevId: 456922766
2022-06-23 20:54:21 -07:00
jax authors
c9c258ea9b Merge pull request #11215 from jakevdp:roots-jit
PiperOrigin-RevId: 456880017
2022-06-23 15:57:54 -07:00
Matthew Johnson
5f97dc8954 Roll forward with simple fix: handle Zero cotangents in _broadcast_in_dim
transpose rule (previously handled by the deflinear2 wrapper, which it's no
longer using).

PiperOrigin-RevId: 456874635
2022-06-23 15:30:22 -07:00
Jake VanderPlas
f6476f7a03 jnp.roots: better support for computation under JIT 2022-06-23 14:48:53 -07:00
Peter Hawkins
22304eeb2e Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
jax authors
2744404809 Merge pull request #11230 from jakevdp:fix-numpy-123
PiperOrigin-RevId: 456857412
2022-06-23 14:09:35 -07:00
Jake VanderPlas
617df70135 Unpin numpy to ensure most recent version is tested 2022-06-23 12:23:14 -07:00
Jake VanderPlas
eec1225d74 TST: skip tests on numpy 1.23.0 due to regressions in that release 2022-06-23 11:46:51 -07:00
Jake VanderPlas
e92e23e5f8 Use equality rather than identity when checking for float0
Why? This is required due to changes to dtype canonicalization in numpy v1.23; see #11221
2022-06-23 11:46:20 -07:00
jax authors
e4d1e1beb3 Copybara import of the project:
--
a001c52f878824cd1c0a67c73d9d318ed30286c9 by Matthew Johnson <mattjj@google.com>:

[dynamic-shapes] basic jvp working, including with broadcast

PiperOrigin-RevId: 456822732
2022-06-23 11:32:30 -07:00
jax authors
3737d160b5 Merge pull request #11229 from LenaMartens:changelist/456788425
PiperOrigin-RevId: 456803263
2022-06-23 10:25:30 -07:00
jax authors
a9275d1a25 Merge pull request #11156 from mattjj:djax-ad-jvp
PiperOrigin-RevId: 456797426
2022-06-23 10:02:32 -07:00
Lena Martens
8efeb3e297 Fix getting aval of BatchTracers that are not mapped. 2022-06-23 17:28:45 +01:00
George Necula
391aaf4177 [jax2tf] Fix the documentation for handling dimension polynomials. 2022-06-23 16:51:22 +03:00
jax authors
77a4528bcf Merge pull request #11173 from gnecula:large_prng
PiperOrigin-RevId: 456752688
2022-06-23 06:27:42 -07:00
Kuangyuan Chen
dc1c519547 Reduce jax.jit dispatch overhead by avoiding directly comparing python objects
Previously the thread local state might be updated, leading to expensive python compare logic during compilation cache lookup. This CL adds a thread local cache for the state.

PiperOrigin-RevId: 456667829
2022-06-22 20:04:40 -07:00
Yash Katariya
1908da33af Only initialize GPU backends if they are not already initialized
PiperOrigin-RevId: 456664792
2022-06-22 19:39:52 -07:00
Yash Katariya
b623ed58b0 Add a Multiprocess gpu test to test the distributed.initialize() function.
PiperOrigin-RevId: 456633768
2022-06-22 16:20:47 -07:00
Qiao Zhang
be71989af6 Remove broken image link.
PiperOrigin-RevId: 456628284
2022-06-22 15:54:36 -07:00
jax authors
3711e5f71a Merge pull request #10840 from jakevdp:strict-promotion-default
PiperOrigin-RevId: 456586131
2022-06-22 12:54:20 -07:00
jax authors
86d8a467ba Merge pull request #11186 from jakevdp:x64-promotion-error
PiperOrigin-RevId: 456584041
2022-06-22 12:45:27 -07:00
Yash Katariya
766c5ba0a2 Check sharding in pmap for jax.Array.
The checks are:

(1) Check if the in_axes given to pmap matches the sharding of Array.

(2) Check if devices in `array.sharding` is equal to the devices provided to pmap

(3) Check if devices for all array inputs are the same.

(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).

PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Jake VanderPlas
6439435478 Set jax_numpy_dtype_promotion='strict' in tests 2022-06-22 11:22:09 -07:00
jax authors
6a22f586f9 Merge pull request #11207 from jakevdp:x64-scipy-optimize-test
PiperOrigin-RevId: 456560836
2022-06-22 11:21:03 -07:00
Jake VanderPlas
85660f5363 [x64] make scipy_optimize_test compatible with strict dtype promotion 2022-06-22 11:04:20 -07:00
Ruoxin Sang
0a14a81704 Fix mismatched parentheses in jax2tf code examples.
PiperOrigin-RevId: 456531544
2022-06-22 09:33:18 -07:00