5769 Commits

Author SHA1 Message Date
Yash Katariya
78678ee9e1 Rename count_pjit_cache_miss with count_pjit_cpp_cache_miss because it is confusing which cache the first function is taking about as pjit has many caches
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
6f2256ad17 Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
PiperOrigin-RevId: 521507810
2023-04-03 11:05:25 -07:00
Jake VanderPlas
b37c741c6f accelerate deprecation of jax.curry
PiperOrigin-RevId: 520958381
2023-03-31 10:37:39 -07:00
jax authors
2841bd310e Merge pull request #15321 from jakevdp:remove-msort
PiperOrigin-RevId: 520952178
2023-03-31 10:16:18 -07:00
Zafarali Ahmed
6e00ba8bad Enable more mesh shape assignment
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.

PiperOrigin-RevId: 520942700
2023-03-31 09:36:16 -07:00
Jake VanderPlas
749dc1b95e Remove deprecated function jnp.msort 2023-03-31 08:24:36 -07:00
Matthew Johnson
6a2b081506 fix bug from #15335 by checking main_trace tag 2023-03-30 22:35:03 -07:00
Matthew Johnson
211bc29842 add assertions for axis name shadowing bugs 2023-03-30 21:31:02 -07:00
jax authors
f8bff8dd38 Merge pull request #15332 from mattjj:shmap-vmap-closure
PiperOrigin-RevId: 520788236
2023-03-30 17:43:02 -07:00
Parker Schuh
0bb46856a8 expose compiler_options on compile()
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 520782460
2023-03-30 17:14:26 -07:00
Matthew Johnson
7c3c46c807 [shard-map] handle closed-over vmap tracers 2023-03-30 16:43:40 -07:00
Yash Katariya
16ca0ca15c Relax the tolerance of testCauchyLogCdf
PiperOrigin-RevId: 520741306
2023-03-30 14:19:50 -07:00
Peter Hawkins
67a28ce30f Relax test tolerances for testLogisticPpf.
Fixes a test failure in CI.

PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
jax authors
dedfc8df75 Merge pull request #15282 from JiaYaobo:geom_random
PiperOrigin-RevId: 520635974
2023-03-30 07:45:19 -07:00
jax authors
794769c113 Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
jax authors
0a2e383eaf Merge pull request #15297 from jakevdp:finfo-props
PiperOrigin-RevId: 520632058
2023-03-30 07:22:28 -07:00
Peter Hawkins
47177e1417 Split more targets out the main JAX Bazel target.
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
81de5b7a0d improve pmap in_axes/out_axes pytree prefix error messages 2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f [JAX] Delete _DeviceArray and DeviceArray.
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
Jake VanderPlas
5759bf05df jnp.finfo: add missing properties 2023-03-29 11:23:51 -07:00
jiayaobo
924894e85c add geometric random gen
add geom random

add geom random

add geom random

add geom random
2023-03-30 02:08:04 +08:00
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
jax authors
4061bbbbc2 Merge pull request #15269 from skye:min_jaxlib_version
PiperOrigin-RevId: 520127548
2023-03-28 14:02:27 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
jax authors
bbec461c8b Merge pull request #15263 from jakevdp:deprecations
PiperOrigin-RevId: 520110559
2023-03-28 13:02:32 -07:00
Jake VanderPlas
fc47137ca8 Add deprecation warnings for several top-level jax imports 2023-03-28 12:40:59 -07:00
Yash Katariya
2f105bde2d Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning
PiperOrigin-RevId: 520098757
2023-03-28 12:17:30 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Sharad Vikram
10dc941d8d Add jaxlib version guard for rnn test
PiperOrigin-RevId: 519833650
2023-03-27 14:43:46 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
e21aee18a8 Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Sharad Vikram
3c3fa042e3 Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Yash Katariya
41695cc78c Temporarily fix the compilation cache test which is failing on latest jaxlib release
PiperOrigin-RevId: 519745099
2023-03-27 09:37:37 -07:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
jax authors
c572155cc1 Merge pull request #15212 from google:pjrt_c_api_tests
PiperOrigin-RevId: 519276265
2023-03-24 17:27:40 -07:00
Skye Wanderman-Milne
ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Anish Tondwalkar
6842e98ca1 Migrate regularized_incomplete_beta_p off xla_fallback
PiperOrigin-RevId: 519244597
2023-03-24 14:53:20 -07:00
Anish Tondwalkar
ac44d2c2e3 Migrate besseli0e off xla_fallback
PiperOrigin-RevId: 519241252
2023-03-24 14:39:40 -07:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Parker Schuh
21541e60b1 Guard ArrayImpl checks by xla_extension_version.
PiperOrigin-RevId: 519191714
2023-03-24 11:15:36 -07:00
Yash Katariya
bc231ee0bf After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
PiperOrigin-RevId: 519170785
2023-03-24 10:00:41 -07:00
Anish Tondwalkar
8c75e27f67 Migrate random_gamma_grad off xla_fallback
PiperOrigin-RevId: 519154537
2023-03-24 08:49:40 -07:00
Anish Tondwalkar
8d1d522618 Migrate igamma_grad_a_p off xla_fallback
PiperOrigin-RevId: 519148548
2023-03-24 08:21:22 -07:00
Anish Tondwalkar
4a9b09485e Migrate igammac_p off xla_fallback path
It is now decomposed into stablehlo ops.

PiperOrigin-RevId: 519122775
2023-03-24 05:58:38 -07:00
jax authors
32e712864c Merge pull request #15192 from mattjj:issue15190
PiperOrigin-RevId: 519037959
2023-03-23 20:48:33 -07:00
jax authors
1982a113d6 Merge pull request #15187 from mattjj:djax-revival
PiperOrigin-RevId: 519036576
2023-03-23 20:38:01 -07:00
Matthew Johnson
7743fcd758 [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit 2023-03-23 20:20:01 -07:00
Matthew Johnson
793387e496 fix jax.Array.round()
fixes #15190
2023-03-23 20:16:23 -07:00
Skye Wanderman-Milne
4cb3b011a0 Remove PJRT C API bypass.
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.

PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00