12615 Commits

Author SHA1 Message Date
Yash Katariya
33c4fc4fe2 Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
jax authors
0a783ca156 Merge pull request #11836 from froystig:checkpoint-wraps
PiperOrigin-RevId: 466839961
2022-08-10 19:08:32 -07:00
jax authors
4ecd4db27f Merge pull request #11833 from jakevdp:ufunc-tests
PiperOrigin-RevId: 466838445
2022-08-10 19:02:58 -07:00
jax authors
104e87f358 Merge pull request #11838 from froystig:custom-vjp-upgrade-jep-move
PiperOrigin-RevId: 466838413
2022-08-10 18:57:01 -07:00
Roy Frostig
c6ab3a6a60 convert custom VJP update guide to a retroactive JEP 2022-08-10 13:45:57 -07:00
Roy Frostig
7d494a3852 update checkpoint attributes according to functools.wraps
This updates the signature in addition to `__doc__`, and that gets
picked up by generated API docs.
2022-08-10 13:33:07 -07:00
jax authors
70faf30831 Merge pull request #11832 from hawkinsp:py311
PiperOrigin-RevId: 466768048
2022-08-10 13:11:39 -07:00
Bixia Zheng
bb92038b6f Change jax to lower the asin and atan primitives to their corresponding chlo
ops.

PiperOrigin-RevId: 466766999
2022-08-10 13:05:29 -07:00
Jake VanderPlas
97c32f67fc Tests: reenable some ufunc input tests 2022-08-10 12:45:32 -07:00
Peter Hawkins
23f6ef4e6b Fix Python 3.11 compatibility problems.
Also needs https://github.com/tensorflow/tensorflow/pull/57085
2022-08-10 19:45:24 +00:00
jax authors
e81578a9fa Merge pull request #11780 from ROCmSoftwarePlatform:rocm_update_dockerfile
PiperOrigin-RevId: 466756858
2022-08-10 12:19:13 -07:00
jax authors
8dba660b06 Merge pull request #11830 from mattjj:new-remat-landing-srs
PiperOrigin-RevId: 466736309
2022-08-10 11:03:44 -07:00
Matthew Johnson
be6f6bfe9f set new jax.remat / jax.checkpoint to be on-by-default 2022-08-10 10:29:38 -07:00
jax authors
9922308342 Merge pull request #11818 from google:nightly-issue
PiperOrigin-RevId: 466717610
2022-08-10 10:08:35 -07:00
Peter Hawkins
11c3df4635 Disable NumPy dlpack tests on GPU.
NumPy dlpack support only works with CPU buffers.

PiperOrigin-RevId: 466709063
2022-08-10 09:35:41 -07:00
Peter Hawkins
03590d86c0 Disable lax_numpy test that seems to lead to nonterminating LLVM compilation.
PiperOrigin-RevId: 466682802
2022-08-10 07:50:51 -07:00
jax authors
8b2e4f975c Merge pull request #11825 from mattjj:fix-type-annotation
PiperOrigin-RevId: 466550958
2022-08-09 20:21:10 -07:00
jax authors
582968cec5 Merge pull request #11824 from froystig:eltype-slice
PiperOrigin-RevId: 466550346
2022-08-09 20:15:13 -07:00
Matthew Johnson
d76754e40e fix type annotation on remat 2022-08-09 19:57:40 -07:00
Roy Frostig
7955799ae3 defer to custom eltype for slice lowering rule
We already handled dynamic slice, but plain slice is eltype-polymorphic too.
2022-08-09 19:13:34 -07:00
Jake VanderPlas
7ec6acd981 nightly multiprocess test: create issue on failure 2022-08-09 19:12:32 -07:00
Yash Katariya
8a1b4785de Use the same jaxlib package name for nightlies. The __version__ will still contain the dev version (with datetime string in it).
PiperOrigin-RevId: 466534455
2022-08-09 18:53:36 -07:00
jax authors
169345311a Merge pull request #11807 from pschuh:vmap_sharding_spec
PiperOrigin-RevId: 466472977
2022-08-09 14:42:27 -07:00
Sharad Vikram
3ec1b1b987 Check if XLA Executable has execute_with_token before using it
PiperOrigin-RevId: 466470801
2022-08-09 14:34:57 -07:00
jax authors
6d9512aa39 Merge pull request #11808 from sharadmv:debugger-fix-flatten
PiperOrigin-RevId: 466448203
2022-08-09 13:17:44 -07:00
jax authors
22a4928a74 Merge pull request #11812 from mattjj:fix-ad-checkpoint-traceback-exclusion
PiperOrigin-RevId: 466444604
2022-08-09 13:07:56 -07:00
jax authors
bdc92482aa Merge pull request #11811 from mattjj:remat-landing-fr-fr-fr-fr
PiperOrigin-RevId: 466443765
2022-08-09 13:02:28 -07:00
Matthew Johnson
580e7f39d5 fix traceback exclusion on new checkpoint 2022-08-09 12:45:21 -07:00
Parker Schuh
01df754630
Remove docs 2022-08-09 12:36:49 -07:00
Sharad Vikram
18f164ff1c Try flattening the locals/globals dict in the debugger and have a
fallback if it fails
2022-08-09 12:31:33 -07:00
Matthew Johnson
666f4f838f fix diffrax with new remat, do for cond what #11773 did for while_loop 2022-08-09 12:22:52 -07:00
Rohit Santhanam
1b3542427e [ROCm] Update Dockerfile.rocm. 2022-08-09 11:09:10 -07:00
jax authors
88636d2b64 Merge pull request #11712 from jakevdp:delete-optimizers
PiperOrigin-RevId: 466398079
2022-08-09 10:25:28 -07:00
jax authors
d95b27ce1c Merge pull request #11803 from hawkinsp:multigpu
PiperOrigin-RevId: 466374926
2022-08-09 09:02:21 -07:00
Jake VanderPlas
79406757d0 Remove deprecated jax.experimental.optimizers
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
a2c21958a5 Document multiprocess GPU support.
Fixes #2731
2022-08-09 11:31:05 -04:00
Yash Katariya
ce80a54805 Reshape sharding spec indices to the mesh shape to preserve the old semantics.
PiperOrigin-RevId: 466346873
2022-08-09 06:59:38 -07:00
jax authors
870e8a2928 Merge pull request #11806 from sharadmv:debugger-improvements
PiperOrigin-RevId: 466337260
2022-08-09 06:14:56 -07:00
Marc van Zee
856d91d251 [jax2tf] Fixes a bug in testing.
https://github.com/google/jax/pull/11575 introduced a bug: it catches any exception raised by self.ConvertAndCompare in primitives_test.py here, but does not reraise it if the if-clause is false. This is very dangerous, since any conversion that results in an error (for instance: if an op is missing for enable_xla=False) will now pass.

I had to disable some tests that were introduced somewhere else, most likely in XLA (X64 enabled). I have filed a separate bug for this (internally).

PiperOrigin-RevId: 466269561
2022-08-09 00:34:13 -07:00
Sharad Vikram
c34aa3933f Various debugger improvements
- disables globals
- can opt out of filtering frames
- can limit number of frames
2022-08-08 19:54:29 -07:00
Parker Schuh
8fb957350c Add spmd_axis_name to vmap to allow constraining mapped PartitionSpecs. 2022-08-08 19:41:42 -07:00
Roy Frostig
4c18f1a580 link to both closed and open enhancement proposals
PiperOrigin-RevId: 466212251
2022-08-08 18:48:38 -07:00
jax authors
c4192dca65 Ensure PartitionSpec is picklable.
PartitionSpec instances currently don't roundtrip correctly through pickle, wrapping any value in an extra tuple, e.g.

```
> pickle.loads(pickle.dumps(PartitionSpec('model', 'batch')))
> PartitionSpec(('model', 'batch'),)
```

PiperOrigin-RevId: 466188998
2022-08-08 16:56:45 -07:00
jax authors
38ab3d88ae Merge pull request #11799 from hawkinsp:jep
PiperOrigin-RevId: 466178951
2022-08-08 16:16:18 -07:00
Peter Hawkins
71b29b1cc6 Create JAX Enhancement Proposals (JEPs).
Migrate existing design documents to JEPs.
2022-08-08 16:13:58 -04:00
Parker Schuh
022fedde98 Move protobuf deps to be optional.
PiperOrigin-RevId: 466123822
2022-08-08 12:57:56 -07:00
Yash Katariya
53f16b7709 Calculate indices for pjit/xmap (basically Mesh) cases via op_sharding_to_indices function. Calculation for pmap indices uses the sharding_spec_indices path that it currently takes.
PiperOrigin-RevId: 466120527
2022-08-08 12:45:27 -07:00
jax authors
8dce848f5f Merge pull request #11779 from hawkinsp:npy
PiperOrigin-RevId: 466117679
2022-08-08 12:34:18 -07:00
jax authors
a2a84c40d5 Merge pull request #11797 from jakevdp:matrix-rank-mapped
PiperOrigin-RevId: 466114039
2022-08-08 12:20:41 -07:00
jax authors
ba59b7379b Merge pull request #11796 from sudhakarsingh27:main
PiperOrigin-RevId: 466111110
2022-08-08 12:10:20 -07:00