16039 Commits

Author SHA1 Message Date
Matthew Johnson
f55de18933 [checkify] fix closed_call_p handling
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-05-10 22:00:16 -07:00
Parker Schuh
261ff9e9ed Stop passing CompileOptions when deserializing.
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
jax authors
74df2d758a Merge pull request #15603 from mattjj:shmap-call-lowering
PiperOrigin-RevId: 530996233
2023-05-10 13:49:51 -07:00
Matthew Johnson
8b66f073d1 [shard-map] experiment with lowering to a Call with attrs
Co-authored-by: Bart Chrzaszcz <bartchr@google.com>
2023-05-10 13:14:04 -07:00
jax authors
b0017a7355 Merge pull request #15955 from jakevdp:grad-opaque
PiperOrigin-RevId: 530981558
2023-05-10 12:51:35 -07:00
Jake VanderPlas
b250c706b0 Allow opaque dtypes in grad with allow_int=True 2023-05-10 11:43:17 -07:00
jax authors
81a5a5ee52 Merge pull request #15936 from gnecula:poly_vmap_tests
PiperOrigin-RevId: 530951808
2023-05-10 10:55:16 -07:00
jax authors
d6828c9c35 Merge pull request #15953 from jakevdp:keyarray-dynamic-slice
PiperOrigin-RevId: 530936580
2023-05-10 09:58:49 -07:00
Jake VanderPlas
6ada8785aa PRNGKeyArray: fix dynamic slice index dtype 2023-05-10 09:24:18 -07:00
jax authors
70f0cc4690 Merge pull request #15944 from mattjj:shmap-remove-cast
PiperOrigin-RevId: 530911060
2023-05-10 08:11:19 -07:00
jax authors
538c680e04 Merge pull request #15943 from mattjj:custom-jvp-checkify-symzeros
PiperOrigin-RevId: 530907814
2023-05-10 07:56:40 -07:00
George Necula
1429dd5be2 [shape_poly] Remove old test limitations
When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.

We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
2023-05-10 13:38:24 +02:00
jax authors
48f551378a Merge pull request #15949 from gnecula:fix_poly
PiperOrigin-RevId: 530832168
2023-05-10 00:51:37 -07:00
George Necula
e0518a5154 [shape_poly] Fix shape parsing regression
The changes in #15912 inadvertently have dropped some
error checking for the parsed polymorphic specifications.
2023-05-10 09:32:00 +02:00
Anish Tondwalkar
840461673d Migrate ApproxTopK to StableHLO
This uses an ApproxTopK custom-call, which we add support for in supported by
MHLO, by including a lowering to XLA's PartialReduce custom_call via the Client
XLA ApproxTopK function.

PiperOrigin-RevId: 530805966
2023-05-09 22:31:22 -07:00
jax authors
8aa14337e6 Merge pull request #15912 from gnecula:poly_parse
PiperOrigin-RevId: 530804516
2023-05-09 22:23:29 -07:00
jax authors
bbc96320ed Merge pull request #15947 from skye:version
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658 Add improved TPU SVD accuracy to the changelog.
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f Update versions and changelog for 0.4.9 release 2023-05-09 17:06:59 -07:00
Yash Katariya
954cda9ce1 Move lint_and_typecheck and documentation job to the ubuntu-latest image since we don't need a large machine for it
PiperOrigin-RevId: 530734120
2023-05-09 15:47:22 -07:00
jax authors
1b9180167b Merge pull request #15945 from skye:version
PiperOrigin-RevId: 530722158
jaxlib-v0.4.9 jax-v0.4.9 jax-v0.4.9-rc
2023-05-09 14:59:20 -07:00
Skye Wanderman-Milne
5bcd9dcc46 Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release, take 2 2023-05-09 14:49:54 -07:00
Matthew Johnson
0e14075a35 remove cast 2023-05-09 14:44:05 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Matthew Johnson
391e95a683 fix checkify custom_jvp rule to handle symbolic zeros
likely broken in #15426, or maybe not quite right before either

Co-authored-by: Roy Frostig <frostig@google.com>
2023-05-09 14:12:53 -07:00
Yash Katariya
2694bf6207 Use set equality operators instead of intersection because I didn't know set had equality operators.
PiperOrigin-RevId: 530688786
2023-05-09 12:55:47 -07:00
jax authors
68ba54241c Merge pull request #15929 from gnecula:fix_mlir_ir
PiperOrigin-RevId: 530675418
2023-05-09 12:02:35 -07:00
Peter Hawkins
a89c377762 [GPU] Fix another instance of missing stream synchronization in RNN kernels.
PiperOrigin-RevId: 530660502
2023-05-09 11:08:24 -07:00
jax authors
a2b5bd5230 Merge pull request #15931 from geraschenko:bcoo_reshape
PiperOrigin-RevId: 530657565
2023-05-09 10:58:53 -07:00
Anton Geraschenko
27aa5fb774 Make dimensions argument of bcoo_reshape optional. 2023-05-09 10:38:18 -07:00
Yash Katariya
18d19caa1c Add McJAX resharding to device_put. Allow resharding if inputs and target sharding have the same set of devices but different order.
We can make this general enough in JAX slowly and carefully and would likely require a refactor of how device_assignment is chosen.

Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
2023-05-09 09:58:12 -07:00
jax authors
cf4c1edafa Merge pull request #15920 from froystig:issue15869
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
George Necula
daf6a30f6e Import "ir" directly rather than as "mlir.ir" 2023-05-09 17:55:13 +02:00
jax authors
cb3a4f3dbf Merge pull request #15859 from gnecula:poly_rng
PiperOrigin-RevId: 530606467
2023-05-09 07:43:31 -07:00
George Necula
de2a811fe9 [shape_poly] Improvements and more testing for shape polymorphism for random primitives
* added support for shape polymorphism for partitionable threefry and for
    random_split.
  * removed footgun that was ignoring the partitionable flag in presence of
    shape polymorphism.
  * Replicated the PRNG tests for threefry (partitionable and non-partitionable),
    and unsafe_rbg.
  * Added general support for overriding jax.config flags for PolyHarness

This fixes the known bug with random_gamma.
The known missing feature is shape polymorphism for RngBitGenerator.
https://github.com/openxla/stablehlo/issues/1344
2023-05-09 13:55:27 +02:00
George Necula
16881e623f [shape_poly] Improve testing of vmap test harnesses
Previously, we disabled `check_result` (check that the JAX native and JAX with shape polymorphism produce the same result) for test harnesses that are created by vmap on primitive harnesses if the primitive harness has a custom assertion.

Now we enable that checking even for those harnesses, and we use the same custom assertion.

PiperOrigin-RevId: 530547784
2023-05-09 02:29:22 -07:00
George Necula
d66be780a5 [shape_poly] Add support for shape polymorphism with native serialization for lax.linalg.qr on TPU
PiperOrigin-RevId: 530522799
2023-05-09 00:25:45 -07:00
Roy Frostig
051c5dda6e delegate select lowering to opaque dtype rule
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
jax authors
236c74cad7 Merge pull request #15909 from skye:version
PiperOrigin-RevId: 530450947
2023-05-08 17:22:23 -07:00
Peter Hawkins
f168a1560c [GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.
May fix flaky failures in CI.

Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it.

PiperOrigin-RevId: 530425766
2023-05-08 15:37:04 -07:00
Peter Hawkins
00b75aff82 Add tests for negative inputs to top-k.
Make the top-k test inputs larger.

This test would have caught the top-k bug fixed by https://github.com/openxla/xla/pull/2809

PiperOrigin-RevId: 530398528
2023-05-08 13:51:24 -07:00
George Necula
8394d8d6d4 [shape_poly] Improve the parsed for shape polymorphism specifications.
Previously, we used a simple regexp-based parser, which could only
parse additions of multiplications of dimension variables. Now the
symbolic dimension expressions can also contain "mod" and "floordiv"
which would break the parser in confusing ways.

Now we have a recursive-descent parser, with much better error
reporting support.
2023-05-08 22:28:28 +02:00
Peter Hawkins
6b9a109939 Use stream-synchronized copy in rnn_kernels.cc.
May fix flaky wrong outputs sometimes seen in CI.

Also check for errors in another use of gpuStreamSynchronize().

PiperOrigin-RevId: 530391917
2023-05-08 13:28:08 -07:00
George Necula
821b38da12 [shape_poly] Remove old code disabling some tests
PiperOrigin-RevId: 530390362
2023-05-08 13:20:52 -07:00
Skye Wanderman-Milne
5e9364abc6 Revert setup.py changes.
This reverts the setup.py changes from
f28b20175f307d5a56502446a9706480126a5bd4. We actually need to fix some
more issues before releasing 0.4.9, so fix the install at HEAD in the
meantime.
2023-05-08 09:58:51 -07:00
John QiangZhang
47df8628a0 Fix the problem for tf function return StatefulPartitionedCall during jax2tf.call_tf.
PiperOrigin-RevId: 529964653
2023-05-06 08:30:26 -07:00
jax authors
d508f08121 Merge pull request #15893 from froystig:jex-jep
PiperOrigin-RevId: 529812297
2023-05-05 14:17:12 -07:00
Roy Frostig
ce840a9cd8 JEP: jax.extend, a module for extensions 2023-05-05 13:50:22 -07:00
jax authors
a4382d7600 Merge pull request #15890 from jakevdp:delete-slice
PiperOrigin-RevId: 529782623
2023-05-05 12:19:25 -07:00
Jake VanderPlas
882edd4924 jnp.delete: avoid large trace-time constant when deleting slice 2023-05-05 11:38:58 -07:00