Parker Schuh
82fcfc3851
Buffer -> Array in some pxla type annotations.
...
PiperOrigin-RevId: 520975371
2023-03-31 11:42:22 -07:00
Jake VanderPlas
b37c741c6f
accelerate deprecation of jax.curry
...
PiperOrigin-RevId: 520958381
2023-03-31 10:37:39 -07:00
jax authors
ffb8352848
Merge pull request #15342 from jakevdp:doc-requirements
...
PiperOrigin-RevId: 520955387
2023-03-31 10:27:30 -07:00
jax authors
2841bd310e
Merge pull request #15321 from jakevdp:remove-msort
...
PiperOrigin-RevId: 520952178
2023-03-31 10:16:18 -07:00
Zafarali Ahmed
6e00ba8bad
Enable more mesh shape assignment
...
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.
PiperOrigin-RevId: 520942700
2023-03-31 09:36:16 -07:00
jax authors
dfbbc2551c
Merge pull request #15317 from ROCmSoftwarePlatform:rocm_pmap_fix
...
PiperOrigin-RevId: 520934992
2023-03-31 09:05:07 -07:00
Peter Hawkins
abf1acf76c
Replace references to jax.interpreters with jax._src.interpreters in JAX core.
...
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
jax authors
182cc9857c
Merge pull request #15323 from NeilGirdhar:fix_rayleigh
...
PiperOrigin-RevId: 520932851
2023-03-31 08:50:40 -07:00
Jake VanderPlas
9ec3ad1ce7
DOC: pin newest sphinx-book-theme
2023-03-31 08:42:34 -07:00
Jake VanderPlas
749dc1b95e
Remove deprecated function jnp.msort
2023-03-31 08:24:36 -07:00
jax authors
0df2ddcf0e
Merge pull request #15232 from gnecula:tf_arange
...
PiperOrigin-RevId: 520914838
2023-03-31 07:11:19 -07:00
George Necula
c368c69625
[shape_poly] Extend the handling of jnp.arange with shape polymorphism.
...
Previously, only `arange(stop, dtype=...)` was being handled in presence
of shape polymorphism. Here we extend to add support for `start` and `step`
to be also present. There are still plenty of restrictions:
* no floating point constants are allowed among start, stop and step
* we must resolve statically if step is positive or negative
* we must resolve statically if the distance between start and stop
is negative or positive.
2023-03-31 14:41:26 +02:00
jax authors
76b922aade
Merge pull request #15337 from mattjj:axis-name-shadowing-2
...
PiperOrigin-RevId: 520838748
2023-03-30 23:01:02 -07:00
Matthew Johnson
6a2b081506
fix bug from #15335 by checking main_trace tag
2023-03-30 22:35:03 -07:00
jax authors
12bcdeb69e
Merge pull request #15335 from mattjj:axis-name-shadowing
...
PiperOrigin-RevId: 520829991
2023-03-30 21:56:42 -07:00
Matthew Johnson
211bc29842
add assertions for axis name shadowing bugs
2023-03-30 21:31:02 -07:00
jax authors
d383ab65dc
Merge pull request #15255 from eltociear:patch-6
...
PiperOrigin-RevId: 520814903
2023-03-30 20:38:17 -07:00
jax authors
8e17da477c
Merge pull request #15322 from jakevdp:pre-commit
...
PiperOrigin-RevId: 520793950
2023-03-30 18:26:27 -07:00
jax authors
248ffc2ca2
Merge pull request #15329 from jakevdp:padfunc-protocol-2
...
PiperOrigin-RevId: 520793934
2023-03-30 18:19:43 -07:00
jax authors
61064a1eb5
Merge pull request #15331 from jakevdp:protocol
...
PiperOrigin-RevId: 520793925
2023-03-30 18:12:41 -07:00
jax authors
f8bff8dd38
Merge pull request #15332 from mattjj:shmap-vmap-closure
...
PiperOrigin-RevId: 520788236
2023-03-30 17:43:02 -07:00
Parker Schuh
0bb46856a8
expose compiler_options
on compile()
...
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 520782460
2023-03-30 17:14:26 -07:00
Matthew Johnson
7c3c46c807
[shard-map] handle closed-over vmap tracers
2023-03-30 16:43:40 -07:00
jax authors
d58c970f07
Merge pull request #15327 from jakevdp:fix-user-guides
...
PiperOrigin-RevId: 520767709
2023-03-30 16:14:11 -07:00
jax authors
7d5047f363
Merge pull request #15328 from NeilGirdhar:fix_geom
...
PiperOrigin-RevId: 520767692
2023-03-30 16:07:12 -07:00
Jake VanderPlas
6d006b5994
[typing] use protocol for cumulative reductions
2023-03-30 15:43:43 -07:00
Jake VanderPlas
92386b8524
[typing] use protocol for pad stat_func
2023-03-30 15:07:47 -07:00
Neil Girdhar
78204f7996
Fix broadcasting in jax.random.geometric
2023-03-30 17:54:18 -04:00
Yash Katariya
69c9660aab
Raise deprecation warnings for {in|out}_axis_resources
for pjit and axis_resources
for with_sharding_constraint
...
PiperOrigin-RevId: 520748845
2023-03-30 14:51:01 -07:00
Peter Hawkins
36bf14b044
Remove some dead code.
...
PiperOrigin-RevId: 520746309
2023-03-30 14:41:26 -07:00
Jake VanderPlas
ec63d699e9
DOC: fix headings in user_guides
2023-03-30 14:39:25 -07:00
jax authors
ef29ff166c
Merge pull request #15300 from skye:version
...
PiperOrigin-RevId: 520744521
2023-03-30 14:34:16 -07:00
jax authors
01c6863877
Merge pull request #15301 from google:timeout
...
PiperOrigin-RevId: 520742762
2023-03-30 14:26:50 -07:00
Yash Katariya
16ca0ca15c
Relax the tolerance of testCauchyLogCdf
...
PiperOrigin-RevId: 520741306
2023-03-30 14:19:50 -07:00
Peter Hawkins
31eeaed913
Split mlir.py and xla.py into separate Bazel targets.
...
PiperOrigin-RevId: 520737811
2023-03-30 14:06:16 -07:00
Yash Katariya
c978df5dbb
Delete unused functions from dispatch.py and pjit.py
...
PiperOrigin-RevId: 520730163
2023-03-30 13:38:44 -07:00
Neil Girdhar
1d1b131f4b
Fix broadcasting in jax.random.rayleigh
2023-03-30 16:38:08 -04:00
Jake VanderPlas
b30e6e7d59
CI: add EOF and debug precommit hooks
2023-03-30 13:29:50 -07:00
Peter Hawkins
23451dc764
Merge pull request #15303 from jakevdp:lax-asarray
...
PiperOrigin-RevId: 520717999
2023-03-30 20:11:11 +00:00
Rahul Batra
13e45c8953
[ROCm]: Run pmap test on specific number of GPUs
2023-03-30 18:34:47 +00:00
Jake VanderPlas
8f72454bdf
Add internal jax.lax.asarray utility
2023-03-30 10:21:55 -07:00
Peter Hawkins
67a28ce30f
Relax test tolerances for testLogisticPpf.
...
Fixes a test failure in CI.
PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
jax authors
dedfc8df75
Merge pull request #15282 from JiaYaobo:geom_random
...
PiperOrigin-RevId: 520635974
2023-03-30 07:45:19 -07:00
jax authors
1fd6e01289
Merge pull request #15287 from gnecula:tf_dim_vars
...
PiperOrigin-RevId: 520633830
2023-03-30 07:37:47 -07:00
jax authors
794769c113
Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
...
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
jax authors
0a2e383eaf
Merge pull request #15297 from jakevdp:finfo-props
...
PiperOrigin-RevId: 520632058
2023-03-30 07:22:28 -07:00
George Necula
081b86b82a
[shape_poly] Improved computation of dimension variables for native serialization
...
Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)
The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape if `2 *b`.
This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.
What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.
After we land this, we will refactor the shape environment to be cleaner.
2023-03-30 15:51:24 +02:00
Peter Hawkins
47177e1417
Split more targets out the main JAX Bazel target.
...
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
81de5b7a0d
improve pmap in_axes/out_axes pytree prefix error messages
2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f
[JAX] Delete _DeviceArray and DeviceArray.
...
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00