12232 Commits

Author SHA1 Message Date
Robert Suderman
499a4e733c Expose ml_program dialect for MLIR builder
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
2022-06-28 20:29:41 +00:00
jax authors
90af8e8135 Merge pull request #11278 from sharadmv:for-loop
PiperOrigin-RevId: 457798502
2022-06-28 12:54:30 -07:00
Yash Katariya
6d8c6f8fac Make astype work for Array that are sharded. The current behavior is the same as SDA i.e. it round trips via host.
PiperOrigin-RevId: 457797458
2022-06-28 12:49:12 -07:00
Arjun Sharda
cc8e302933 Update ci-build.yaml
Update ci-build.yaml
2022-06-28 12:45:59 -07:00
Peter Hawkins
dbae3e5ed1 Remove long-deprecated omnistaging flag.
PiperOrigin-RevId: 457794581
2022-06-28 12:35:08 -07:00
Dan F-M
0788d5708a Implementation of jax.scipy.stats.gaussian_kde 2022-06-28 15:17:12 -04:00
jax authors
6835dc18e3 Merge pull request #11259 from jakevdp:average-keepdims
PiperOrigin-RevId: 457777845
2022-06-28 11:20:38 -07:00
jax authors
b62466813b Merge pull request #11287 from jakevdp:fix-choice
PiperOrigin-RevId: 457775054
2022-06-28 11:08:49 -07:00
Jake VanderPlas
df800f39d3 jnp.average: support keepdims argument 2022-06-28 10:55:55 -07:00
Jake VanderPlas
bcb45557c4 random.choice: make compatible with strict promotion 2022-06-28 10:41:30 -07:00
jax authors
4fd6733049 Merge pull request #11289 from jakevdp:fix-mypy
PiperOrigin-RevId: 457750453
2022-06-28 09:33:50 -07:00
Jake VanderPlas
1ccbe0909f CI: fix mypy error 2022-06-28 09:16:36 -07:00
jax authors
b1baf4a343 Merge pull request #11288 from hawkinsp:cpuwarn
PiperOrigin-RevId: 457743040
2022-06-28 09:02:54 -07:00
Peter Hawkins
71f18bec24 Disable test for GPU/TPU warning on Mac.
We previously disabled the GPU/TPU warning on Mac so the test no longer passes. We don't show the warning because we don't support GPUs or TPUs on Mac.
2022-06-28 11:44:25 -04:00
Sharad Vikram
e1ba52bb25 Add tests for jvp(for_loop) 2022-06-27 18:20:32 -07:00
Sharad Vikram
c4b938ffbe Add jvp rule for for_loop
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-06-27 18:20:13 -07:00
John QiangZhang
de464fcf22 update jax2tf README: add walkaround about tf.Module magic conversion.
Here we will use tree_util.flatten and unflatten to provide a general walkwaround
for tfModule Dict->_DictWrapper conversion. It will works for List and Tuple.

PiperOrigin-RevId: 457597147
2022-06-27 16:50:24 -07:00
jax authors
5b464c2956 Merge pull request #11214 from sharadmv:for-loop
PiperOrigin-RevId: 457572546
2022-06-27 14:53:09 -07:00
Sharad Vikram
236a445b49 Add for_loop primitive and impl rule
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-06-27 14:13:31 -07:00
jax authors
bfe9d9f2ed Merge pull request #11274 from google:dependabot/github_actions/styfle/cancel-workflow-action-0.10.0
PiperOrigin-RevId: 457559934
2022-06-27 14:02:11 -07:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
jaxlib-v0.3.14 jax-v0.3.14
2022-06-27 13:56:32 -07:00
Yash Katariya
5b865ed557 Add the assume_metadata option to avoid waiting on ts.open which is very slow on large models. Instead await on the future in the background thread.
PiperOrigin-RevId: 457540435
2022-06-27 12:30:23 -07:00
jax authors
a2f1aee7f3 Merge pull request #11248 from jakevdp:remove-jaxtestcase
PiperOrigin-RevId: 457524591
2022-06-27 11:18:57 -07:00
Jake VanderPlas
887abbc3b9 jax.test_util: remove deprecated test classes.
JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
2022-06-27 11:04:50 -07:00
dependabot[bot]
19ae5d8581
Bump styfle/cancel-workflow-action from 0.9.1 to 0.10.0
Bumps [styfle/cancel-workflow-action](https://github.com/styfle/cancel-workflow-action) from 0.9.1 to 0.10.0.
- [Release notes](https://github.com/styfle/cancel-workflow-action/releases)
- [Commits](https://github.com/styfle/cancel-workflow-action/compare/0.9.1...0.10.0)

---
updated-dependencies:
- dependency-name: styfle/cancel-workflow-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2022-06-27 17:17:21 +00:00
jax authors
997beb3ce0 Merge pull request #11273 from hawkinsp:release
PiperOrigin-RevId: 457466335
2022-06-27 06:50:28 -07:00
Peter Hawkins
1e29b7b762 Update CHANGELOG.md and setup.py for 0.3.14 release. 2022-06-27 09:38:41 -04:00
jax authors
02603606e7 Merge pull request #11244 from hawkinsp:xla
PiperOrigin-RevId: 457461421
2022-06-27 06:20:41 -07:00
Peter Hawkins
f4ddd3ef88 Update XLA. 2022-06-27 09:14:28 -04:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
jax authors
93f5113c93 Merge pull request #11250 from LenaMartens:changelist/456788427
PiperOrigin-RevId: 457430831
2022-06-27 03:00:23 -07:00
Lena Martens
740fe6926a Checkify: add (checkify-of-)vmap-of-check. 2022-06-27 10:34:26 +01:00
Peter Whidden
cb137ee885
update bazel on windows docs url 2022-06-25 13:32:16 -04:00
jax authors
406a61cf52 Merge pull request #11146 from sshahrokhi:AbortIfNotInitialized
PiperOrigin-RevId: 457115405
2022-06-24 16:24:57 -07:00
jax authors
62c16da81f Merge pull request #11255 from ikmckenz:fix-broken-links-design-notes
PiperOrigin-RevId: 457101918
2022-06-24 15:08:10 -07:00
Shiva Shahrokhi
df8c6263de Change JAX_PLATFORMS to raise an exception when platform initialization fails 2022-06-24 21:54:53 +00:00
jax authors
4ad0234d85 Merge pull request #11251 from hawkinsp:scipy
PiperOrigin-RevId: 457070748
2022-06-24 12:41:51 -07:00
Ian McKenzie
0cc2ada432 Fix broken links for moved design_notes folder 2022-06-24 12:18:11 -07:00
Peter Hawkins
a560a29e12 Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00
Yash Katariya
989a3304bf Fix the creation of pmap sharding spec when sharded_dim is None.
PiperOrigin-RevId: 457045869
2022-06-24 10:46:35 -07:00
Yash Katariya
e32373c3ea Make jnp.array return jax.Array. Add input and result handlers for jax.Array. Also added tests for add under jit.
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
jax authors
53286a9312 Merge pull request #11247 from hawkinsp:maxwell
PiperOrigin-RevId: 457024985
2022-06-24 09:13:57 -07:00
Peter Hawkins
fc659d5308 Reduce size of double-sided maxwell random test.
It appears that for some inputs this triggers an integer overflow in scipy.stats.maxwell().cdf.
2022-06-24 12:01:20 -04:00
Marc van Zee
4c25ef1d00 Simplifies inverting permutation.
PiperOrigin-RevId: 457013218
2022-06-24 08:06:21 -07:00
jax authors
a90bde2c54 Merge pull request #11231 from hawkinsp:remotetpu
PiperOrigin-RevId: 457005076
2022-06-24 07:13:16 -07:00
jax authors
932f77e3d5 Merge pull request #11226 from gnecula:tf_bug_fix
PiperOrigin-RevId: 456932679
2022-06-23 22:21:06 -07:00
Matthew Johnson
8c5632123b fix ad_util.Zero handling in broadcast_in_dim_jvp_rule
PiperOrigin-RevId: 456922766
2022-06-23 20:54:21 -07:00
jax authors
c9c258ea9b Merge pull request #11215 from jakevdp:roots-jit
PiperOrigin-RevId: 456880017
2022-06-23 15:57:54 -07:00
Matthew Johnson
5f97dc8954 Roll forward with simple fix: handle Zero cotangents in _broadcast_in_dim
transpose rule (previously handled by the deflinear2 wrapper, which it's no
longer using).

PiperOrigin-RevId: 456874635
2022-06-23 15:30:22 -07:00
Jake VanderPlas
f6476f7a03 jnp.roots: better support for computation under JIT 2022-06-23 14:48:53 -07:00