16662 Commits

Author SHA1 Message Date
Jake VanderPlas
21f6736005 Remove several deprecated APIs 2023-07-11 12:42:32 -07:00
Jake VanderPlas
a29d4bcd33 remove deprecation warning test in preparation for removing deprecated APIs
PiperOrigin-RevId: 547229078
2023-07-11 10:52:10 -07:00
Jake VanderPlas
b581ad1f33 Remove several deprecated jax.Array methods:
- arr.broadcast
- arr.broadcast_in_dim
- arr.split

These have been deprecated since JAX v0.4.5

PiperOrigin-RevId: 547228974
2023-07-11 10:34:27 -07:00
Juliana Franco
f81a48a819 Makes it possible to lower primitives with user-defined lowering rules.
PiperOrigin-RevId: 547228102
2023-07-11 10:26:07 -07:00
jax authors
17c4b57f97 Merge pull request #16671 from jakevdp:std-args
PiperOrigin-RevId: 547227744
2023-07-11 10:25:53 -07:00
jax authors
35e0a5fd04 Merge pull request #16682 from hawkinsp:distinct
PiperOrigin-RevId: 547225300
2023-07-11 10:17:05 -07:00
jax authors
949ad1f9d0 Merge pull request #16683 from hawkinsp:win
PiperOrigin-RevId: 547222283
2023-07-11 10:06:45 -07:00
Peter Hawkins
a1a5159fbf Note that the MSVC studio 2019 redistributable is required for JAX on Windows.
Issue #16664
2023-07-11 12:45:13 -04:00
jax authors
e894e4817a Remove deprecated compiler_ir from Compiled
PiperOrigin-RevId: 547211085
2023-07-11 09:24:48 -07:00
Peter Hawkins
1d4b10b775 Remove --distinct_host_configuration from Bazel flags.
This flag does nothing under Bazel 6 and will be removed in Bazel 7.
2023-07-11 11:38:05 -04:00
jax authors
4c800f5a8a Improve error message to point the way to Megacore.
PiperOrigin-RevId: 547194562
2023-07-11 08:16:33 -07:00
jax authors
2fa6a9c9bf Allow other backends to run the array_test.py test.
PiperOrigin-RevId: 547191886
2023-07-11 08:05:25 -07:00
jax authors
60d481078e Merge pull request #16523 from ROCmSoftwarePlatform:rocm-update-build-doc
PiperOrigin-RevId: 547185925
2023-07-11 07:41:51 -07:00
jax authors
06a7ea91b0 Merge pull request #16491 from hawkinsp:iree
PiperOrigin-RevId: 547184602
2023-07-11 07:32:53 -07:00
jax authors
f7a71e4ca5 Merge pull request #16543 from ROCmSoftwarePlatform:rocm-enable-eighidentity-test
PiperOrigin-RevId: 547179014
2023-07-11 07:13:28 -07:00
Peter Hawkins
3692feb414 Remove the old JAX/IREE integration.
JAX-on-IREE should use the openxla-pjrt-plugin path (https://github.com/openxla/openxla-pjrt-plugin).
2023-07-11 10:10:52 -04:00
jax authors
3ec5f73db0 Merge pull request #16542 from ROCmSoftwarePlatform:rocm-enable-svdontiny-test
PiperOrigin-RevId: 547178719
2023-07-11 07:04:45 -07:00
jax authors
392914bd46 Merge pull request #16677 from froystig:aot-docs
PiperOrigin-RevId: 547078408
2023-07-10 22:38:39 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
Roy Frostig
14e38a3f9d AOT doc: fix lower/compile expression in error example 2023-07-10 18:27:06 -07:00
jax authors
ef76ccfc1b Merge pull request #16672 from jakevdp:mapped-apply
PiperOrigin-RevId: 547021542
2023-07-10 16:55:59 -07:00
Jake VanderPlas
1b3da85758 Fix scatter batching rule for scatter_apply
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
2023-07-10 16:42:45 -07:00
jax authors
f4eed78e90 Merge pull request #16645 from treyra:main
PiperOrigin-RevId: 546987247
2023-07-10 14:41:31 -07:00
Jake VanderPlas
d7bb9f85d6 NumpySignaturesTest: account for 'mean' param to std/var 2023-07-10 09:56:17 -07:00
jax authors
b19c63278d Merge pull request #16668 from LenaMartens:check-cleanup
PiperOrigin-RevId: 546898495
2023-07-10 09:41:50 -07:00
lenamartens
1d5e858ea9 Checkify: remove duplicate line. 2023-07-10 15:41:35 +01:00
treyra
b0c309a25c Added test for vmap inconsistent sized arrays msg 2023-07-09 20:46:40 -07:00
jax authors
1795b12a9f Merge pull request #16654 from jakevdp:ml-dtypes-version
PiperOrigin-RevId: 546366165
2023-07-07 13:13:55 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
jax authors
ded88a83a6 Merge pull request #16541 from axch:ragged-transformer
PiperOrigin-RevId: 546334495
2023-07-07 11:23:59 -07:00
Alexey Radul
defe71228c Clearer test names. 2023-07-07 09:23:33 -04:00
Alexey Radul
5077807c8b Abstract out and reuse the gather_shape_computation to predict which axes will end up ragged.
This should resolve worries about silently wrong metadata about
pile_mapped gather, but gather is complicated so it's hard to be sure.
2023-07-07 09:23:33 -04:00
Alexey Radul
9fdc14f0bf More type annotations, and make transpose_ragged_axes a top-level function instead of a method.
Keep move_stacked_axis as a method because it's a type-specific
version of a top-level function of the same name that already exists.
2023-07-07 09:23:33 -04:00
Alexey Radul
aa3c49f134 Test a different configuration of einsum.
This version stresses my transpose_ragged_axes method, which, it
seems, was interpreting the permutation the wrong way.  Fixed.
2023-07-07 09:23:33 -04:00
Alexey Radul
9d918c4448 Force RaggedAxis axes to be sorted at rest, so that == is more reliable. 2023-07-07 09:23:33 -04:00
Alexey Radul
89dd69ea2d Test and implement ragged slicing.
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
  q, k, v = qkv
(which turns into three slices on an non-batched and non-ragged axis).

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:33 -04:00
Alexey Radul
6f09fe840e Better error message when broadcasting ragged to static shape.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:29 -04:00
jax authors
3a0c135b7e Merge pull request #16116 from sharadmv:blah
PiperOrigin-RevId: 546156775
2023-07-06 19:47:10 -07:00
treyra
9ec6ebb7e0 Fixed _mapped_axis_size raising an uncaught TypeError 2023-07-06 18:51:44 -07:00
Sharad Vikram
c446b42522 Add discharge rules for scan/while 2023-07-06 22:30:35 +00:00
jax authors
f08e52faef Merge pull request #16642 from jakevdp:slice-in-dim
PiperOrigin-RevId: 546044892
2023-07-06 11:27:07 -07:00
Jake VanderPlas
7c0334ce15 DOC: improve documentation for lax slicing routines 2023-07-06 10:44:08 -07:00
jax authors
e2478685b9 Merge pull request #16635 from froystig:random-test-cleanups2
PiperOrigin-RevId: 545838671
2023-07-05 18:29:09 -07:00
Roy Frostig
ce9c2d650a rename seed_prng test method to make_key 2023-07-05 15:26:30 -07:00
Roy Frostig
ff70255af9 consistently seed keys indirectly by test class method in LaxRandomTest 2023-07-05 15:18:54 -07:00
Roy Frostig
556c1123cf parameterize two random tests over key constructors 2023-07-05 15:18:54 -07:00
Roy Frostig
c710c7578d move and remove code in random_test 2023-07-05 15:18:54 -07:00
jax authors
7c7051a4cc Merge pull request #16607 from froystig:random-test-double-threefry
PiperOrigin-RevId: 545799083
2023-07-05 15:15:35 -07:00
Roy Frostig
30542bd5bd match behavior of double-threefry test RNG and standard threefry RNG
This also lets us avoid a guard on `config.jax_enable_custom_prng` in
random tests.
2023-07-05 15:01:12 -07:00
jax authors
42f0b21486 Merge pull request #16609 from froystig:random-unsafe-rbg-test-always-typed
PiperOrigin-RevId: 545790077
2023-07-05 14:41:07 -07:00