15546 Commits

Author SHA1 Message Date
Parker Schuh
82fcfc3851 Buffer -> Array in some pxla type annotations.
PiperOrigin-RevId: 520975371
2023-03-31 11:42:22 -07:00
Jake VanderPlas
b37c741c6f accelerate deprecation of jax.curry
PiperOrigin-RevId: 520958381
2023-03-31 10:37:39 -07:00
jax authors
ffb8352848 Merge pull request #15342 from jakevdp:doc-requirements
PiperOrigin-RevId: 520955387
2023-03-31 10:27:30 -07:00
jax authors
2841bd310e Merge pull request #15321 from jakevdp:remove-msort
PiperOrigin-RevId: 520952178
2023-03-31 10:16:18 -07:00
Zafarali Ahmed
6e00ba8bad Enable more mesh shape assignment
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.

PiperOrigin-RevId: 520942700
2023-03-31 09:36:16 -07:00
jax authors
dfbbc2551c Merge pull request #15317 from ROCmSoftwarePlatform:rocm_pmap_fix
PiperOrigin-RevId: 520934992
2023-03-31 09:05:07 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
jax authors
182cc9857c Merge pull request #15323 from NeilGirdhar:fix_rayleigh
PiperOrigin-RevId: 520932851
2023-03-31 08:50:40 -07:00
Jake VanderPlas
9ec3ad1ce7 DOC: pin newest sphinx-book-theme 2023-03-31 08:42:34 -07:00
Jake VanderPlas
749dc1b95e Remove deprecated function jnp.msort 2023-03-31 08:24:36 -07:00
jax authors
0df2ddcf0e Merge pull request #15232 from gnecula:tf_arange
PiperOrigin-RevId: 520914838
2023-03-31 07:11:19 -07:00
George Necula
c368c69625 [shape_poly] Extend the handling of jnp.arange with shape polymorphism.
Previously, only `arange(stop, dtype=...)` was being handled in presence
of shape polymorphism. Here we extend to add support for `start` and `step`
to be also present. There are still plenty of restrictions:

   * no floating point constants are allowed among start, stop and step
   * we must resolve statically if step is positive or negative
   * we must resolve statically if the distance between start and stop
     is negative or positive.
2023-03-31 14:41:26 +02:00
jax authors
76b922aade Merge pull request #15337 from mattjj:axis-name-shadowing-2
PiperOrigin-RevId: 520838748
2023-03-30 23:01:02 -07:00
Matthew Johnson
6a2b081506 fix bug from #15335 by checking main_trace tag 2023-03-30 22:35:03 -07:00
jax authors
12bcdeb69e Merge pull request #15335 from mattjj:axis-name-shadowing
PiperOrigin-RevId: 520829991
2023-03-30 21:56:42 -07:00
Matthew Johnson
211bc29842 add assertions for axis name shadowing bugs 2023-03-30 21:31:02 -07:00
jax authors
d383ab65dc Merge pull request #15255 from eltociear:patch-6
PiperOrigin-RevId: 520814903
2023-03-30 20:38:17 -07:00
jax authors
8e17da477c Merge pull request #15322 from jakevdp:pre-commit
PiperOrigin-RevId: 520793950
2023-03-30 18:26:27 -07:00
jax authors
248ffc2ca2 Merge pull request #15329 from jakevdp:padfunc-protocol-2
PiperOrigin-RevId: 520793934
2023-03-30 18:19:43 -07:00
jax authors
61064a1eb5 Merge pull request #15331 from jakevdp:protocol
PiperOrigin-RevId: 520793925
2023-03-30 18:12:41 -07:00
jax authors
f8bff8dd38 Merge pull request #15332 from mattjj:shmap-vmap-closure
PiperOrigin-RevId: 520788236
2023-03-30 17:43:02 -07:00
Parker Schuh
0bb46856a8 expose compiler_options on compile()
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 520782460
2023-03-30 17:14:26 -07:00
Matthew Johnson
7c3c46c807 [shard-map] handle closed-over vmap tracers 2023-03-30 16:43:40 -07:00
jax authors
d58c970f07 Merge pull request #15327 from jakevdp:fix-user-guides
PiperOrigin-RevId: 520767709
2023-03-30 16:14:11 -07:00
jax authors
7d5047f363 Merge pull request #15328 from NeilGirdhar:fix_geom
PiperOrigin-RevId: 520767692
2023-03-30 16:07:12 -07:00
Jake VanderPlas
6d006b5994 [typing] use protocol for cumulative reductions 2023-03-30 15:43:43 -07:00
Jake VanderPlas
92386b8524 [typing] use protocol for pad stat_func 2023-03-30 15:07:47 -07:00
Neil Girdhar
78204f7996 Fix broadcasting in jax.random.geometric 2023-03-30 17:54:18 -04:00
Yash Katariya
69c9660aab Raise deprecation warnings for {in|out}_axis_resources for pjit and axis_resources for with_sharding_constraint
PiperOrigin-RevId: 520748845
2023-03-30 14:51:01 -07:00
Peter Hawkins
36bf14b044 Remove some dead code.
PiperOrigin-RevId: 520746309
2023-03-30 14:41:26 -07:00
Jake VanderPlas
ec63d699e9 DOC: fix headings in user_guides 2023-03-30 14:39:25 -07:00
jax authors
ef29ff166c Merge pull request #15300 from skye:version
PiperOrigin-RevId: 520744521
2023-03-30 14:34:16 -07:00
jax authors
01c6863877 Merge pull request #15301 from google:timeout
PiperOrigin-RevId: 520742762
2023-03-30 14:26:50 -07:00
Yash Katariya
16ca0ca15c Relax the tolerance of testCauchyLogCdf
PiperOrigin-RevId: 520741306
2023-03-30 14:19:50 -07:00
Peter Hawkins
31eeaed913 Split mlir.py and xla.py into separate Bazel targets.
PiperOrigin-RevId: 520737811
2023-03-30 14:06:16 -07:00
Yash Katariya
c978df5dbb Delete unused functions from dispatch.py and pjit.py
PiperOrigin-RevId: 520730163
2023-03-30 13:38:44 -07:00
Neil Girdhar
1d1b131f4b Fix broadcasting in jax.random.rayleigh 2023-03-30 16:38:08 -04:00
Jake VanderPlas
b30e6e7d59 CI: add EOF and debug precommit hooks 2023-03-30 13:29:50 -07:00
Peter Hawkins
23451dc764 Merge pull request #15303 from jakevdp:lax-asarray
PiperOrigin-RevId: 520717999
2023-03-30 20:11:11 +00:00
Rahul Batra
13e45c8953 [ROCm]: Run pmap test on specific number of GPUs 2023-03-30 18:34:47 +00:00
Jake VanderPlas
8f72454bdf Add internal jax.lax.asarray utility 2023-03-30 10:21:55 -07:00
Peter Hawkins
67a28ce30f Relax test tolerances for testLogisticPpf.
Fixes a test failure in CI.

PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
jax authors
dedfc8df75 Merge pull request #15282 from JiaYaobo:geom_random
PiperOrigin-RevId: 520635974
2023-03-30 07:45:19 -07:00
jax authors
1fd6e01289 Merge pull request #15287 from gnecula:tf_dim_vars
PiperOrigin-RevId: 520633830
2023-03-30 07:37:47 -07:00
jax authors
794769c113 Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
jax authors
0a2e383eaf Merge pull request #15297 from jakevdp:finfo-props
PiperOrigin-RevId: 520632058
2023-03-30 07:22:28 -07:00
George Necula
081b86b82a [shape_poly] Improved computation of dimension variables for native serialization
Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)

The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape if `2 *b`.

This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.

What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.

After we land this, we will refactor the shape environment to be cleaner.
2023-03-30 15:51:24 +02:00
Peter Hawkins
47177e1417 Split more targets out the main JAX Bazel target.
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
81de5b7a0d improve pmap in_axes/out_axes pytree prefix error messages 2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f [JAX] Delete _DeviceArray and DeviceArray.
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00