16055 Commits

Author SHA1 Message Date
Yash Katariya
b748db8ef1 Fix global_array_to_host_local_array when the specified pspec and mesh do not match the sharding of the input array.
In that case, reshard the array and then create a host local array from that.

Also improve the shard mismatch error that jax.Array raises.

PiperOrigin-RevId: 531397741
2023-05-11 22:02:58 -07:00
jax authors
f75e86c085 Merge pull request #15979 from skye:version
PiperOrigin-RevId: 531370738
2023-05-11 19:32:52 -07:00
Skye Wanderman-Milne
533a7c05f1 Update versions and changelog post 0.4.10 release 2023-05-11 18:16:02 -07:00
jax authors
21fc6e0229 Merge pull request #15978 from skye:version
PiperOrigin-RevId: 531321316
jax-v0.4.10 jaxlib-v0.4.10 jax-v0.4.10-rc
2023-05-11 15:20:30 -07:00
Skye Wanderman-Milne
82bbeef519 Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release 2023-05-11 14:46:06 -07:00
Parker Schuh
6e8181495a Construct topologies and hook up aot_test for pjrt_c_api.
PiperOrigin-RevId: 531310241
2023-05-11 14:37:04 -07:00
jax authors
d8c487b5c7 Merge pull request #15956 from sharadmv:pure-callback-maximal
PiperOrigin-RevId: 531304370
2023-05-11 14:14:49 -07:00
jax authors
0037ab6240 [PJRT C API] Check whether the PJRT_Api* for the device type already exists before calling dlopen and dlsym.
PiperOrigin-RevId: 531295150
2023-05-11 13:43:17 -07:00
Sharad Vikram
61f22676b0 Add maximal sharding for pure_callback not inside of a shard_map 2023-05-11 13:28:37 -07:00
Yash Katariya
1bef7c9787 Fix McJAX resharding when the input has a fully replicated sharding
PiperOrigin-RevId: 531263333
2023-05-11 11:42:36 -07:00
Parker Schuh
11b34a90fd Skip stream-executor for aot_test.py.
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Parker Schuh
ee20330631 Default None-topology platform to TPU.
PiperOrigin-RevId: 531245559
2023-05-11 10:41:45 -07:00
George Necula
e57794f176 [shape_poly] Fix test breakage.
In cl/530804516 we changed the parser for polymorphic shape
specifications and also changed the error message. This
lead to failure in the TF.js jax_conversion_test.

We improve the error message and adjust the jax_conversion_test
to match the new message.

PiperOrigin-RevId: 531238700
2023-05-11 10:19:16 -07:00
Anish Tondwalkar
fd3216a41f Migrate ApproxTopK fallback to StableHLO
PiperOrigin-RevId: 531222034
2023-05-11 09:17:58 -07:00
Peter Hawkins
9471bb3045 Disable sparsify_test on CPU under tsan.
Under tsan this test times out in CI.

PiperOrigin-RevId: 531210930
2023-05-11 08:33:35 -07:00
jax authors
6a68750f35 Merge pull request #15958 from mattjj:checkify-closed-call
PiperOrigin-RevId: 531098683
2023-05-10 22:30:10 -07:00
Matthew Johnson
f55de18933 [checkify] fix closed_call_p handling
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-05-10 22:00:16 -07:00
Parker Schuh
261ff9e9ed Stop passing CompileOptions when deserializing.
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
jax authors
74df2d758a Merge pull request #15603 from mattjj:shmap-call-lowering
PiperOrigin-RevId: 530996233
2023-05-10 13:49:51 -07:00
Matthew Johnson
8b66f073d1 [shard-map] experiment with lowering to a Call with attrs
Co-authored-by: Bart Chrzaszcz <bartchr@google.com>
2023-05-10 13:14:04 -07:00
jax authors
b0017a7355 Merge pull request #15955 from jakevdp:grad-opaque
PiperOrigin-RevId: 530981558
2023-05-10 12:51:35 -07:00
Jake VanderPlas
b250c706b0 Allow opaque dtypes in grad with allow_int=True 2023-05-10 11:43:17 -07:00
jax authors
81a5a5ee52 Merge pull request #15936 from gnecula:poly_vmap_tests
PiperOrigin-RevId: 530951808
2023-05-10 10:55:16 -07:00
jax authors
d6828c9c35 Merge pull request #15953 from jakevdp:keyarray-dynamic-slice
PiperOrigin-RevId: 530936580
2023-05-10 09:58:49 -07:00
Jake VanderPlas
6ada8785aa PRNGKeyArray: fix dynamic slice index dtype 2023-05-10 09:24:18 -07:00
jax authors
70f0cc4690 Merge pull request #15944 from mattjj:shmap-remove-cast
PiperOrigin-RevId: 530911060
2023-05-10 08:11:19 -07:00
jax authors
538c680e04 Merge pull request #15943 from mattjj:custom-jvp-checkify-symzeros
PiperOrigin-RevId: 530907814
2023-05-10 07:56:40 -07:00
George Necula
1429dd5be2 [shape_poly] Remove old test limitations
When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.

We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
2023-05-10 13:38:24 +02:00
jax authors
48f551378a Merge pull request #15949 from gnecula:fix_poly
PiperOrigin-RevId: 530832168
2023-05-10 00:51:37 -07:00
George Necula
e0518a5154 [shape_poly] Fix shape parsing regression
The changes in #15912 inadvertently have dropped some
error checking for the parsed polymorphic specifications.
2023-05-10 09:32:00 +02:00
Anish Tondwalkar
840461673d Migrate ApproxTopK to StableHLO
This uses an ApproxTopK custom-call, which we add support for in supported by
MHLO, by including a lowering to XLA's PartialReduce custom_call via the Client
XLA ApproxTopK function.

PiperOrigin-RevId: 530805966
2023-05-09 22:31:22 -07:00
jax authors
8aa14337e6 Merge pull request #15912 from gnecula:poly_parse
PiperOrigin-RevId: 530804516
2023-05-09 22:23:29 -07:00
jax authors
bbc96320ed Merge pull request #15947 from skye:version
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658 Add improved TPU SVD accuracy to the changelog.
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f Update versions and changelog for 0.4.9 release 2023-05-09 17:06:59 -07:00
Yash Katariya
954cda9ce1 Move lint_and_typecheck and documentation job to the ubuntu-latest image since we don't need a large machine for it
PiperOrigin-RevId: 530734120
2023-05-09 15:47:22 -07:00
jax authors
1b9180167b Merge pull request #15945 from skye:version
PiperOrigin-RevId: 530722158
jaxlib-v0.4.9 jax-v0.4.9 jax-v0.4.9-rc
2023-05-09 14:59:20 -07:00
Skye Wanderman-Milne
5bcd9dcc46 Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release, take 2 2023-05-09 14:49:54 -07:00
Matthew Johnson
0e14075a35 remove cast 2023-05-09 14:44:05 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Matthew Johnson
391e95a683 fix checkify custom_jvp rule to handle symbolic zeros
likely broken in #15426, or maybe not quite right before either

Co-authored-by: Roy Frostig <frostig@google.com>
2023-05-09 14:12:53 -07:00
Yash Katariya
2694bf6207 Use set equality operators instead of intersection because I didn't know set had equality operators.
PiperOrigin-RevId: 530688786
2023-05-09 12:55:47 -07:00
jax authors
68ba54241c Merge pull request #15929 from gnecula:fix_mlir_ir
PiperOrigin-RevId: 530675418
2023-05-09 12:02:35 -07:00
Peter Hawkins
a89c377762 [GPU] Fix another instance of missing stream synchronization in RNN kernels.
PiperOrigin-RevId: 530660502
2023-05-09 11:08:24 -07:00
jax authors
a2b5bd5230 Merge pull request #15931 from geraschenko:bcoo_reshape
PiperOrigin-RevId: 530657565
2023-05-09 10:58:53 -07:00
Anton Geraschenko
27aa5fb774 Make dimensions argument of bcoo_reshape optional. 2023-05-09 10:38:18 -07:00
Yash Katariya
18d19caa1c Add McJAX resharding to device_put. Allow resharding if inputs and target sharding have the same set of devices but different order.
We can make this general enough in JAX slowly and carefully and would likely require a refactor of how device_assignment is chosen.

Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
2023-05-09 09:58:12 -07:00
jax authors
cf4c1edafa Merge pull request #15920 from froystig:issue15869
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
George Necula
daf6a30f6e Import "ir" directly rather than as "mlir.ir" 2023-05-09 17:55:13 +02:00
jax authors
cb3a4f3dbf Merge pull request #15859 from gnecula:poly_rng
PiperOrigin-RevId: 530606467
2023-05-09 07:43:31 -07:00