Yash Katariya
78678ee9e1
Rename count_pjit_cache_miss
with count_pjit_cpp_cache_miss
because it is confusing which cache the first function is taking about as pjit has many caches
...
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
6f2256ad17
Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
...
PiperOrigin-RevId: 521507810
2023-04-03 11:05:25 -07:00
Jake VanderPlas
b37c741c6f
accelerate deprecation of jax.curry
...
PiperOrigin-RevId: 520958381
2023-03-31 10:37:39 -07:00
jax authors
2841bd310e
Merge pull request #15321 from jakevdp:remove-msort
...
PiperOrigin-RevId: 520952178
2023-03-31 10:16:18 -07:00
Zafarali Ahmed
6e00ba8bad
Enable more mesh shape assignment
...
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.
PiperOrigin-RevId: 520942700
2023-03-31 09:36:16 -07:00
Jake VanderPlas
749dc1b95e
Remove deprecated function jnp.msort
2023-03-31 08:24:36 -07:00
Matthew Johnson
6a2b081506
fix bug from #15335 by checking main_trace tag
2023-03-30 22:35:03 -07:00
Matthew Johnson
211bc29842
add assertions for axis name shadowing bugs
2023-03-30 21:31:02 -07:00
jax authors
f8bff8dd38
Merge pull request #15332 from mattjj:shmap-vmap-closure
...
PiperOrigin-RevId: 520788236
2023-03-30 17:43:02 -07:00
Parker Schuh
0bb46856a8
expose compiler_options
on compile()
...
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 520782460
2023-03-30 17:14:26 -07:00
Matthew Johnson
7c3c46c807
[shard-map] handle closed-over vmap tracers
2023-03-30 16:43:40 -07:00
Yash Katariya
16ca0ca15c
Relax the tolerance of testCauchyLogCdf
...
PiperOrigin-RevId: 520741306
2023-03-30 14:19:50 -07:00
Peter Hawkins
67a28ce30f
Relax test tolerances for testLogisticPpf.
...
Fixes a test failure in CI.
PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
jax authors
dedfc8df75
Merge pull request #15282 from JiaYaobo:geom_random
...
PiperOrigin-RevId: 520635974
2023-03-30 07:45:19 -07:00
jax authors
794769c113
Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
...
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
jax authors
0a2e383eaf
Merge pull request #15297 from jakevdp:finfo-props
...
PiperOrigin-RevId: 520632058
2023-03-30 07:22:28 -07:00
Peter Hawkins
47177e1417
Split more targets out the main JAX Bazel target.
...
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
81de5b7a0d
improve pmap in_axes/out_axes pytree prefix error messages
2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f
[JAX] Delete _DeviceArray and DeviceArray.
...
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
Jake VanderPlas
5759bf05df
jnp.finfo: add missing properties
2023-03-29 11:23:51 -07:00
jiayaobo
924894e85c
add geometric random gen
...
add geom random
add geom random
add geom random
add geom random
2023-03-30 02:08:04 +08:00
Yash Katariya
fbc05ee5ac
Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
...
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
jax authors
4061bbbbc2
Merge pull request #15269 from skye:min_jaxlib_version
...
PiperOrigin-RevId: 520127548
2023-03-28 14:02:27 -07:00
Skye Wanderman-Milne
00acf459c6
Bump minimum jaxlib version from 0.4.6 to 0.4.7.
...
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
jax authors
bbec461c8b
Merge pull request #15263 from jakevdp:deprecations
...
PiperOrigin-RevId: 520110559
2023-03-28 13:02:32 -07:00
Jake VanderPlas
fc47137ca8
Add deprecation warnings for several top-level jax imports
2023-03-28 12:40:59 -07:00
Yash Katariya
2f105bde2d
Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning
...
PiperOrigin-RevId: 520098757
2023-03-28 12:17:30 -07:00
Jake VanderPlas
ad0fc8979b
jax.scipy.linalg.expm: support batched inputs
2023-03-27 16:39:48 -07:00
Sharad Vikram
10dc941d8d
Add jaxlib version guard for rnn test
...
PiperOrigin-RevId: 519833650
2023-03-27 14:43:46 -07:00
Peter Hawkins
6cc1bf54a1
Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
...
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.
PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
e21aee18a8
Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
...
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Sharad Vikram
3c3fa042e3
Copy seq_lengths before creating descriptor
...
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Yash Katariya
41695cc78c
Temporarily fix the compilation cache test which is failing on latest jaxlib release
...
PiperOrigin-RevId: 519745099
2023-03-27 09:37:37 -07:00
Yash Katariya
a5d308542e
Add src
argument to device_put as an experimental arg
...
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
jax authors
c572155cc1
Merge pull request #15212 from google:pjrt_c_api_tests
...
PiperOrigin-RevId: 519276265
2023-03-24 17:27:40 -07:00
Skye Wanderman-Milne
ef5e4a4035
Remove 'pjrt_c_api_unimplemented' pytest mark.
...
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Anish Tondwalkar
6842e98ca1
Migrate regularized_incomplete_beta_p off xla_fallback
...
PiperOrigin-RevId: 519244597
2023-03-24 14:53:20 -07:00
Anish Tondwalkar
ac44d2c2e3
Migrate besseli0e off xla_fallback
...
PiperOrigin-RevId: 519241252
2023-03-24 14:39:40 -07:00
Peter Hawkins
6ed66ada0f
Delete remote TPU support.
...
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.
PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Parker Schuh
21541e60b1
Guard ArrayImpl checks by xla_extension_version.
...
PiperOrigin-RevId: 519191714
2023-03-24 11:15:36 -07:00
Yash Katariya
bc231ee0bf
After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
...
PiperOrigin-RevId: 519170785
2023-03-24 10:00:41 -07:00
Anish Tondwalkar
8c75e27f67
Migrate random_gamma_grad off xla_fallback
...
PiperOrigin-RevId: 519154537
2023-03-24 08:49:40 -07:00
Anish Tondwalkar
8d1d522618
Migrate igamma_grad_a_p off xla_fallback
...
PiperOrigin-RevId: 519148548
2023-03-24 08:21:22 -07:00
Anish Tondwalkar
4a9b09485e
Migrate igammac_p off xla_fallback path
...
It is now decomposed into stablehlo ops.
PiperOrigin-RevId: 519122775
2023-03-24 05:58:38 -07:00
jax authors
32e712864c
Merge pull request #15192 from mattjj:issue15190
...
PiperOrigin-RevId: 519037959
2023-03-23 20:48:33 -07:00
jax authors
1982a113d6
Merge pull request #15187 from mattjj:djax-revival
...
PiperOrigin-RevId: 519036576
2023-03-23 20:38:01 -07:00
Matthew Johnson
7743fcd758
[dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit
2023-03-23 20:20:01 -07:00
Matthew Johnson
793387e496
fix jax.Array.round()
...
fixes #15190
2023-03-23 20:16:23 -07:00
Skye Wanderman-Milne
4cb3b011a0
Remove PJRT C API bypass.
...
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.
PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
Peter Hawkins
b7375b316b
Increase minimum NumPy version to 1.21.
...
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00