12232 Commits

Author SHA1 Message Date
Peter Hawkins
22304eeb2e Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
jax authors
2744404809 Merge pull request #11230 from jakevdp:fix-numpy-123
PiperOrigin-RevId: 456857412
2022-06-23 14:09:35 -07:00
Jake VanderPlas
617df70135 Unpin numpy to ensure most recent version is tested 2022-06-23 12:23:14 -07:00
Jake VanderPlas
eec1225d74 TST: skip tests on numpy 1.23.0 due to regressions in that release 2022-06-23 11:46:51 -07:00
Jake VanderPlas
e92e23e5f8 Use equality rather than identity when checking for float0
Why? This is required due to changes to dtype canonicalization in numpy v1.23; see #11221
2022-06-23 11:46:20 -07:00
jax authors
e4d1e1beb3 Copybara import of the project:
--
a001c52f878824cd1c0a67c73d9d318ed30286c9 by Matthew Johnson <mattjj@google.com>:

[dynamic-shapes] basic jvp working, including with broadcast

PiperOrigin-RevId: 456822732
2022-06-23 11:32:30 -07:00
jax authors
3737d160b5 Merge pull request #11229 from LenaMartens:changelist/456788425
PiperOrigin-RevId: 456803263
2022-06-23 10:25:30 -07:00
jax authors
a9275d1a25 Merge pull request #11156 from mattjj:djax-ad-jvp
PiperOrigin-RevId: 456797426
2022-06-23 10:02:32 -07:00
Lena Martens
8efeb3e297 Fix getting aval of BatchTracers that are not mapped. 2022-06-23 17:28:45 +01:00
George Necula
391aaf4177 [jax2tf] Fix the documentation for handling dimension polynomials. 2022-06-23 16:51:22 +03:00
jax authors
77a4528bcf Merge pull request #11173 from gnecula:large_prng
PiperOrigin-RevId: 456752688
2022-06-23 06:27:42 -07:00
Kuangyuan Chen
dc1c519547 Reduce jax.jit dispatch overhead by avoiding directly comparing python objects
Previously the thread local state might be updated, leading to expensive python compare logic during compilation cache lookup. This CL adds a thread local cache for the state.

PiperOrigin-RevId: 456667829
2022-06-22 20:04:40 -07:00
Yash Katariya
1908da33af Only initialize GPU backends if they are not already initialized
PiperOrigin-RevId: 456664792
2022-06-22 19:39:52 -07:00
Yash Katariya
b623ed58b0 Add a Multiprocess gpu test to test the distributed.initialize() function.
PiperOrigin-RevId: 456633768
2022-06-22 16:20:47 -07:00
Qiao Zhang
be71989af6 Remove broken image link.
PiperOrigin-RevId: 456628284
2022-06-22 15:54:36 -07:00
jax authors
3711e5f71a Merge pull request #10840 from jakevdp:strict-promotion-default
PiperOrigin-RevId: 456586131
2022-06-22 12:54:20 -07:00
jax authors
86d8a467ba Merge pull request #11186 from jakevdp:x64-promotion-error
PiperOrigin-RevId: 456584041
2022-06-22 12:45:27 -07:00
Yash Katariya
766c5ba0a2 Check sharding in pmap for jax.Array.
The checks are:

(1) Check if the in_axes given to pmap matches the sharding of Array.

(2) Check if devices in `array.sharding` is equal to the devices provided to pmap

(3) Check if devices for all array inputs are the same.

(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).

PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Jake VanderPlas
6439435478 Set jax_numpy_dtype_promotion='strict' in tests 2022-06-22 11:22:09 -07:00
jax authors
6a22f586f9 Merge pull request #11207 from jakevdp:x64-scipy-optimize-test
PiperOrigin-RevId: 456560836
2022-06-22 11:21:03 -07:00
Jake VanderPlas
85660f5363 [x64] make scipy_optimize_test compatible with strict dtype promotion 2022-06-22 11:04:20 -07:00
Ruoxin Sang
0a14a81704 Fix mismatched parentheses in jax2tf code examples.
PiperOrigin-RevId: 456531544
2022-06-22 09:33:18 -07:00
Yash Katariya
1b21d2c3f5 Return Array from jax.device_put if config.jax_array is enabled.
PiperOrigin-RevId: 456531510
2022-06-22 09:20:56 -07:00
Yash Katariya
dce8f64b40 Make device_put_sharded and device_put_replicated return Arrays.
PiperOrigin-RevId: 456525113
2022-06-22 08:51:29 -07:00
Yash Katariya
6f3b3ac8f9 Add __repr__ to Shard since its not a dataclass anymore
PiperOrigin-RevId: 456463979
2022-06-22 02:26:00 -07:00
Roy Frostig
621dfb9347 remove unused stages.Lowered._xla_computation
PiperOrigin-RevId: 456412379
2022-06-21 20:04:54 -07:00
Roy Frostig
6a0c05fa20 remove unused stages.Compiled._xla_executable
PiperOrigin-RevId: 456412184
2022-06-21 19:59:16 -07:00
Yash Katariya
e5031d15de Disable xla sharding propagation test for SE because XLA sharding propagation is not supported on SE which is activate when out_axis_resources is not specified in pjit.
PiperOrigin-RevId: 456391444
2022-06-21 17:40:31 -07:00
jax authors
15a3798424 Merge pull request #11147 from jakevdp:x64-scipy-signal-test
PiperOrigin-RevId: 456389803
2022-06-21 17:32:15 -07:00
jax authors
1d6cbc93dd Merge pull request #11192 from sharadmv:release
PiperOrigin-RevId: 456355961
2022-06-21 14:48:41 -07:00
Sharad Vikram
217d898124 Update TF version for jaxlib build 2022-06-21 14:34:22 -07:00
jax authors
c6e3f00eb3 Merge pull request #11189 from jakevdp:fix-jvp-doc
PiperOrigin-RevId: 456341327
2022-06-21 13:46:01 -07:00
jax authors
f85886ba65 Merge pull request #11188 from jakevdp:doc-fix
PiperOrigin-RevId: 456339293
2022-06-21 13:40:32 -07:00
jax authors
2dbf3e0795 Merge pull request #11187 from sharadmv:release
PiperOrigin-RevId: 456338899
2022-06-21 13:34:56 -07:00
Jake VanderPlas
abcfaec6e3 DOC: clarify variable names 2022-06-21 13:20:53 -07:00
Sharad Vikram
9bd1bd67e0 Update versions for jax/jaxlib release 2022-06-21 12:57:28 -07:00
Jake VanderPlas
ed152953ef DOC: fix output repr in thinking_in_jax 2022-06-21 12:54:50 -07:00
jax authors
cd11aeca83 Merge pull request #11138 from jakevdp:x64-scipy-optimize-test
PiperOrigin-RevId: 456327057
2022-06-21 12:47:34 -07:00
jax authors
da46e268b2 Merge pull request #11152 from jakevdp:x64-jet-test
PiperOrigin-RevId: 456325791
2022-06-21 12:42:23 -07:00
jax authors
5669be791c Merge pull request #11184 from jakevdp:update-myst-nb
PiperOrigin-RevId: 456324987
2022-06-21 12:36:33 -07:00
Yash Katariya
bf1b29362a Fix up some comments after pmap+Array CL.
PiperOrigin-RevId: 456310489
2022-06-21 11:32:46 -07:00
Jake VanderPlas
997c90b3bb [x64] mention flag value in strict type promotion error 2022-06-21 10:55:24 -07:00
jax authors
2c34097671 Merge pull request #11183 from tmi:docs-fix-profiler
PiperOrigin-RevId: 456285950
2022-06-21 10:01:36 -07:00
Jake VanderPlas
3a8f478b0a [x64] make scipy_optimize_test compatible with strict dtype promotion 2022-06-21 09:29:08 -07:00
Jake VanderPlas
03f2189f90 [x64] make jax.scipy.signal compatible with strict dtype promotion.
Also a fair bit of cleanup & refactoring of related code.
2022-06-21 09:28:46 -07:00
Jake VanderPlas
893179fcfc [x64] make jet_test compatible with strict dtype promotion 2022-06-21 09:28:24 -07:00
jax authors
d7ec244c0c Merge pull request #11160 from jakevdp:x64-xmap-test
PiperOrigin-RevId: 456274863
2022-06-21 09:13:38 -07:00
Jake VanderPlas
b5b78b1052 CI: build docs with most recent myst-nb version 2022-06-21 08:44:23 -07:00
vojta tuma
4e057c9f23 DOC: update pprof install instructions 2022-06-21 11:15:14 +02:00
jax authors
30365b8bcc Merge pull request #10852 from LenaMartens:changelist/442814308
PiperOrigin-RevId: 456115702
2022-06-20 13:10:12 -07:00