Yash Katariya
b748db8ef1
Fix global_array_to_host_local_array when the specified pspec and mesh do not match the sharding of the input array.
...
In that case, reshard the array and then create a host local array from that.
Also improve the shard mismatch error that jax.Array raises.
PiperOrigin-RevId: 531397741
2023-05-11 22:02:58 -07:00
jax authors
f75e86c085
Merge pull request #15979 from skye:version
...
PiperOrigin-RevId: 531370738
2023-05-11 19:32:52 -07:00
Skye Wanderman-Milne
533a7c05f1
Update versions and changelog post 0.4.10 release
2023-05-11 18:16:02 -07:00
jax authors
21fc6e0229
Merge pull request #15978 from skye:version
...
PiperOrigin-RevId: 531321316
jax-v0.4.10
jaxlib-v0.4.10
jax-v0.4.10-rc
2023-05-11 15:20:30 -07:00
Skye Wanderman-Milne
82bbeef519
Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release
2023-05-11 14:46:06 -07:00
Parker Schuh
6e8181495a
Construct topologies and hook up aot_test for pjrt_c_api.
...
PiperOrigin-RevId: 531310241
2023-05-11 14:37:04 -07:00
jax authors
d8c487b5c7
Merge pull request #15956 from sharadmv:pure-callback-maximal
...
PiperOrigin-RevId: 531304370
2023-05-11 14:14:49 -07:00
jax authors
0037ab6240
[PJRT C API] Check whether the PJRT_Api* for the device type already exists before calling dlopen and dlsym.
...
PiperOrigin-RevId: 531295150
2023-05-11 13:43:17 -07:00
Sharad Vikram
61f22676b0
Add maximal sharding for pure_callback not inside of a shard_map
2023-05-11 13:28:37 -07:00
Yash Katariya
1bef7c9787
Fix McJAX resharding when the input has a fully replicated sharding
...
PiperOrigin-RevId: 531263333
2023-05-11 11:42:36 -07:00
Parker Schuh
11b34a90fd
Skip stream-executor for aot_test.py.
...
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Parker Schuh
ee20330631
Default None-topology platform to TPU.
...
PiperOrigin-RevId: 531245559
2023-05-11 10:41:45 -07:00
George Necula
e57794f176
[shape_poly] Fix test breakage.
...
In cl/530804516 we changed the parser for polymorphic shape
specifications and also changed the error message. This
lead to failure in the TF.js jax_conversion_test.
We improve the error message and adjust the jax_conversion_test
to match the new message.
PiperOrigin-RevId: 531238700
2023-05-11 10:19:16 -07:00
Anish Tondwalkar
fd3216a41f
Migrate ApproxTopK fallback to StableHLO
...
PiperOrigin-RevId: 531222034
2023-05-11 09:17:58 -07:00
Peter Hawkins
9471bb3045
Disable sparsify_test on CPU under tsan.
...
Under tsan this test times out in CI.
PiperOrigin-RevId: 531210930
2023-05-11 08:33:35 -07:00
jax authors
6a68750f35
Merge pull request #15958 from mattjj:checkify-closed-call
...
PiperOrigin-RevId: 531098683
2023-05-10 22:30:10 -07:00
Matthew Johnson
f55de18933
[checkify] fix closed_call_p handling
...
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-05-10 22:00:16 -07:00
Parker Schuh
261ff9e9ed
Stop passing CompileOptions when deserializing.
...
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
jax authors
74df2d758a
Merge pull request #15603 from mattjj:shmap-call-lowering
...
PiperOrigin-RevId: 530996233
2023-05-10 13:49:51 -07:00
Matthew Johnson
8b66f073d1
[shard-map] experiment with lowering to a Call with attrs
...
Co-authored-by: Bart Chrzaszcz <bartchr@google.com>
2023-05-10 13:14:04 -07:00
jax authors
b0017a7355
Merge pull request #15955 from jakevdp:grad-opaque
...
PiperOrigin-RevId: 530981558
2023-05-10 12:51:35 -07:00
Jake VanderPlas
b250c706b0
Allow opaque dtypes in grad with allow_int=True
2023-05-10 11:43:17 -07:00
jax authors
81a5a5ee52
Merge pull request #15936 from gnecula:poly_vmap_tests
...
PiperOrigin-RevId: 530951808
2023-05-10 10:55:16 -07:00
jax authors
d6828c9c35
Merge pull request #15953 from jakevdp:keyarray-dynamic-slice
...
PiperOrigin-RevId: 530936580
2023-05-10 09:58:49 -07:00
Jake VanderPlas
6ada8785aa
PRNGKeyArray: fix dynamic slice index dtype
2023-05-10 09:24:18 -07:00
jax authors
70f0cc4690
Merge pull request #15944 from mattjj:shmap-remove-cast
...
PiperOrigin-RevId: 530911060
2023-05-10 08:11:19 -07:00
jax authors
538c680e04
Merge pull request #15943 from mattjj:custom-jvp-checkify-symzeros
...
PiperOrigin-RevId: 530907814
2023-05-10 07:56:40 -07:00
George Necula
1429dd5be2
[shape_poly] Remove old test limitations
...
When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.
We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
2023-05-10 13:38:24 +02:00
jax authors
48f551378a
Merge pull request #15949 from gnecula:fix_poly
...
PiperOrigin-RevId: 530832168
2023-05-10 00:51:37 -07:00
George Necula
e0518a5154
[shape_poly] Fix shape parsing regression
...
The changes in #15912 inadvertently have dropped some
error checking for the parsed polymorphic specifications.
2023-05-10 09:32:00 +02:00
Anish Tondwalkar
840461673d
Migrate ApproxTopK to StableHLO
...
This uses an ApproxTopK custom-call, which we add support for in supported by
MHLO, by including a lowering to XLA's PartialReduce custom_call via the Client
XLA ApproxTopK function.
PiperOrigin-RevId: 530805966
2023-05-09 22:31:22 -07:00
jax authors
8aa14337e6
Merge pull request #15912 from gnecula:poly_parse
...
PiperOrigin-RevId: 530804516
2023-05-09 22:23:29 -07:00
jax authors
bbc96320ed
Merge pull request #15947 from skye:version
...
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658
Add improved TPU SVD accuracy to the changelog.
...
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f
Update versions and changelog for 0.4.9 release
2023-05-09 17:06:59 -07:00
Yash Katariya
954cda9ce1
Move lint_and_typecheck and documentation job to the ubuntu-latest image since we don't need a large machine for it
...
PiperOrigin-RevId: 530734120
2023-05-09 15:47:22 -07:00
jax authors
1b9180167b
Merge pull request #15945 from skye:version
...
PiperOrigin-RevId: 530722158
jaxlib-v0.4.9
jax-v0.4.9
jax-v0.4.9-rc
2023-05-09 14:59:20 -07:00
Skye Wanderman-Milne
5bcd9dcc46
Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release, take 2
2023-05-09 14:49:54 -07:00
Matthew Johnson
0e14075a35
remove cast
2023-05-09 14:44:05 -07:00
Yash Katariya
befa29b566
Fix the cache on to_gspmd_sharding
to depend on if device/backend is set on pjit/jit.
...
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.
This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.
PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Matthew Johnson
391e95a683
fix checkify custom_jvp rule to handle symbolic zeros
...
likely broken in #15426 , or maybe not quite right before either
Co-authored-by: Roy Frostig <frostig@google.com>
2023-05-09 14:12:53 -07:00
Yash Katariya
2694bf6207
Use set equality operators instead of intersection because I didn't know set had equality operators.
...
PiperOrigin-RevId: 530688786
2023-05-09 12:55:47 -07:00
jax authors
68ba54241c
Merge pull request #15929 from gnecula:fix_mlir_ir
...
PiperOrigin-RevId: 530675418
2023-05-09 12:02:35 -07:00
Peter Hawkins
a89c377762
[GPU] Fix another instance of missing stream synchronization in RNN kernels.
...
PiperOrigin-RevId: 530660502
2023-05-09 11:08:24 -07:00
jax authors
a2b5bd5230
Merge pull request #15931 from geraschenko:bcoo_reshape
...
PiperOrigin-RevId: 530657565
2023-05-09 10:58:53 -07:00
Anton Geraschenko
27aa5fb774
Make dimensions
argument of bcoo_reshape optional.
2023-05-09 10:38:18 -07:00
Yash Katariya
18d19caa1c
Add McJAX resharding to device_put. Allow resharding if inputs and target sharding have the same set of devices but different order.
...
We can make this general enough in JAX slowly and carefully and would likely require a refactor of how device_assignment is chosen.
Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
2023-05-09 09:58:12 -07:00
jax authors
cf4c1edafa
Merge pull request #15920 from froystig:issue15869
...
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
George Necula
daf6a30f6e
Import "ir" directly rather than as "mlir.ir"
2023-05-09 17:55:13 +02:00
jax authors
cb3a4f3dbf
Merge pull request #15859 from gnecula:poly_rng
...
PiperOrigin-RevId: 530606467
2023-05-09 07:43:31 -07:00