16066 Commits

Author SHA1 Message Date
Peter Hawkins
f7f1ddbb1e Temporarily disable LaxBackedScipyStatsTests.testTruncnormPdf.
This test started failing at LLVM head.

PiperOrigin-RevId: 532095958
2023-05-15 06:52:46 -07:00
Yash Katariya
78f6246c2f Put the value on the first CPU device uncommitted so that we can transfer it to the right device when executing the computation.
PiperOrigin-RevId: 531996001
2023-05-14 22:31:34 -07:00
jax authors
7aefc9a5b4 Merge pull request #15986 from shoyer:vmap-annotations
PiperOrigin-RevId: 531579516
2023-05-12 12:49:47 -07:00
John QiangZhang
2c05fe996e Add a new test to cover multiple calls to same tf function when call_tf_graph = True.
PiperOrigin-RevId: 531578811
2023-05-12 12:42:42 -07:00
Yash Katariya
559b837ba5 Add logging if we get a C++ cache miss
PiperOrigin-RevId: 531555996
2023-05-12 11:19:58 -07:00
jax authors
0bc3136fbc Merge pull request #15987 from mattjj:ann-doctest-fix
PiperOrigin-RevId: 531555328
2023-05-12 11:12:44 -07:00
Matthew Johnson
12d52663a4 fix doctest CI failure fron ann.py vestigial print 2023-05-12 10:46:05 -07:00
Yash Katariya
009af38697 Remove gda_serialization from JAX. It's replacement is array_serialization
PiperOrigin-RevId: 531525875
2023-05-12 09:23:05 -07:00
Stephan Hoyer
3254bf8d9e Fix type annotations on jax.vmap 2023-05-12 09:08:33 -07:00
jax authors
0dd5ef4864 Merge pull request #15985 from hawkinsp:readme
PiperOrigin-RevId: 531514083
2023-05-12 08:33:09 -07:00
Peter Hawkins
46d30ba29e Drop mentions of the CUDA 11.4/CUDNN 8.2 wheel.
This wheel is no longer shipped as part of jaxlib releases; please upgrade if you were using this.
2023-05-12 10:41:48 -04:00
Yash Katariya
b748db8ef1 Fix global_array_to_host_local_array when the specified pspec and mesh do not match the sharding of the input array.
In that case, reshard the array and then create a host local array from that.

Also improve the shard mismatch error that jax.Array raises.

PiperOrigin-RevId: 531397741
2023-05-11 22:02:58 -07:00
jax authors
f75e86c085 Merge pull request #15979 from skye:version
PiperOrigin-RevId: 531370738
2023-05-11 19:32:52 -07:00
Skye Wanderman-Milne
533a7c05f1 Update versions and changelog post 0.4.10 release 2023-05-11 18:16:02 -07:00
jax authors
21fc6e0229 Merge pull request #15978 from skye:version
PiperOrigin-RevId: 531321316
jax-v0.4.10 jaxlib-v0.4.10 jax-v0.4.10-rc
2023-05-11 15:20:30 -07:00
Skye Wanderman-Milne
82bbeef519 Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release 2023-05-11 14:46:06 -07:00
Parker Schuh
6e8181495a Construct topologies and hook up aot_test for pjrt_c_api.
PiperOrigin-RevId: 531310241
2023-05-11 14:37:04 -07:00
jax authors
d8c487b5c7 Merge pull request #15956 from sharadmv:pure-callback-maximal
PiperOrigin-RevId: 531304370
2023-05-11 14:14:49 -07:00
jax authors
0037ab6240 [PJRT C API] Check whether the PJRT_Api* for the device type already exists before calling dlopen and dlsym.
PiperOrigin-RevId: 531295150
2023-05-11 13:43:17 -07:00
Sharad Vikram
61f22676b0 Add maximal sharding for pure_callback not inside of a shard_map 2023-05-11 13:28:37 -07:00
Yash Katariya
1bef7c9787 Fix McJAX resharding when the input has a fully replicated sharding
PiperOrigin-RevId: 531263333
2023-05-11 11:42:36 -07:00
Parker Schuh
11b34a90fd Skip stream-executor for aot_test.py.
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Parker Schuh
ee20330631 Default None-topology platform to TPU.
PiperOrigin-RevId: 531245559
2023-05-11 10:41:45 -07:00
George Necula
e57794f176 [shape_poly] Fix test breakage.
In cl/530804516 we changed the parser for polymorphic shape
specifications and also changed the error message. This
lead to failure in the TF.js jax_conversion_test.

We improve the error message and adjust the jax_conversion_test
to match the new message.

PiperOrigin-RevId: 531238700
2023-05-11 10:19:16 -07:00
Anish Tondwalkar
fd3216a41f Migrate ApproxTopK fallback to StableHLO
PiperOrigin-RevId: 531222034
2023-05-11 09:17:58 -07:00
Peter Hawkins
9471bb3045 Disable sparsify_test on CPU under tsan.
Under tsan this test times out in CI.

PiperOrigin-RevId: 531210930
2023-05-11 08:33:35 -07:00
jax authors
6a68750f35 Merge pull request #15958 from mattjj:checkify-closed-call
PiperOrigin-RevId: 531098683
2023-05-10 22:30:10 -07:00
Matthew Johnson
f55de18933 [checkify] fix closed_call_p handling
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-05-10 22:00:16 -07:00
Parker Schuh
261ff9e9ed Stop passing CompileOptions when deserializing.
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
jax authors
74df2d758a Merge pull request #15603 from mattjj:shmap-call-lowering
PiperOrigin-RevId: 530996233
2023-05-10 13:49:51 -07:00
Matthew Johnson
8b66f073d1 [shard-map] experiment with lowering to a Call with attrs
Co-authored-by: Bart Chrzaszcz <bartchr@google.com>
2023-05-10 13:14:04 -07:00
jax authors
b0017a7355 Merge pull request #15955 from jakevdp:grad-opaque
PiperOrigin-RevId: 530981558
2023-05-10 12:51:35 -07:00
Jake VanderPlas
b250c706b0 Allow opaque dtypes in grad with allow_int=True 2023-05-10 11:43:17 -07:00
jax authors
81a5a5ee52 Merge pull request #15936 from gnecula:poly_vmap_tests
PiperOrigin-RevId: 530951808
2023-05-10 10:55:16 -07:00
jax authors
d6828c9c35 Merge pull request #15953 from jakevdp:keyarray-dynamic-slice
PiperOrigin-RevId: 530936580
2023-05-10 09:58:49 -07:00
Jake VanderPlas
6ada8785aa PRNGKeyArray: fix dynamic slice index dtype 2023-05-10 09:24:18 -07:00
jax authors
70f0cc4690 Merge pull request #15944 from mattjj:shmap-remove-cast
PiperOrigin-RevId: 530911060
2023-05-10 08:11:19 -07:00
jax authors
538c680e04 Merge pull request #15943 from mattjj:custom-jvp-checkify-symzeros
PiperOrigin-RevId: 530907814
2023-05-10 07:56:40 -07:00
George Necula
1429dd5be2 [shape_poly] Remove old test limitations
When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.

We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
2023-05-10 13:38:24 +02:00
jax authors
48f551378a Merge pull request #15949 from gnecula:fix_poly
PiperOrigin-RevId: 530832168
2023-05-10 00:51:37 -07:00
George Necula
e0518a5154 [shape_poly] Fix shape parsing regression
The changes in #15912 inadvertently have dropped some
error checking for the parsed polymorphic specifications.
2023-05-10 09:32:00 +02:00
Anish Tondwalkar
840461673d Migrate ApproxTopK to StableHLO
This uses an ApproxTopK custom-call, which we add support for in supported by
MHLO, by including a lowering to XLA's PartialReduce custom_call via the Client
XLA ApproxTopK function.

PiperOrigin-RevId: 530805966
2023-05-09 22:31:22 -07:00
jax authors
8aa14337e6 Merge pull request #15912 from gnecula:poly_parse
PiperOrigin-RevId: 530804516
2023-05-09 22:23:29 -07:00
jax authors
bbc96320ed Merge pull request #15947 from skye:version
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658 Add improved TPU SVD accuracy to the changelog.
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f Update versions and changelog for 0.4.9 release 2023-05-09 17:06:59 -07:00
Yash Katariya
954cda9ce1 Move lint_and_typecheck and documentation job to the ubuntu-latest image since we don't need a large machine for it
PiperOrigin-RevId: 530734120
2023-05-09 15:47:22 -07:00
jax authors
1b9180167b Merge pull request #15945 from skye:version
PiperOrigin-RevId: 530722158
jaxlib-v0.4.9 jax-v0.4.9 jax-v0.4.9-rc
2023-05-09 14:59:20 -07:00
Skye Wanderman-Milne
5bcd9dcc46 Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release, take 2 2023-05-09 14:49:54 -07:00
Matthew Johnson
0e14075a35 remove cast 2023-05-09 14:44:05 -07:00