Yash Katariya
33c4fc4fe2
Pmap should output SDA like Array
s to maintain the current behavior exactly. Split the shard_arg_handler for Array
based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
...
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
jax authors
0a783ca156
Merge pull request #11836 from froystig:checkpoint-wraps
...
PiperOrigin-RevId: 466839961
2022-08-10 19:08:32 -07:00
jax authors
4ecd4db27f
Merge pull request #11833 from jakevdp:ufunc-tests
...
PiperOrigin-RevId: 466838445
2022-08-10 19:02:58 -07:00
jax authors
104e87f358
Merge pull request #11838 from froystig:custom-vjp-upgrade-jep-move
...
PiperOrigin-RevId: 466838413
2022-08-10 18:57:01 -07:00
Roy Frostig
c6ab3a6a60
convert custom VJP update guide to a retroactive JEP
2022-08-10 13:45:57 -07:00
Roy Frostig
7d494a3852
update checkpoint
attributes according to functools.wraps
...
This updates the signature in addition to `__doc__`, and that gets
picked up by generated API docs.
2022-08-10 13:33:07 -07:00
jax authors
70faf30831
Merge pull request #11832 from hawkinsp:py311
...
PiperOrigin-RevId: 466768048
2022-08-10 13:11:39 -07:00
Bixia Zheng
bb92038b6f
Change jax to lower the asin and atan primitives to their corresponding chlo
...
ops.
PiperOrigin-RevId: 466766999
2022-08-10 13:05:29 -07:00
Jake VanderPlas
97c32f67fc
Tests: reenable some ufunc input tests
2022-08-10 12:45:32 -07:00
Peter Hawkins
23f6ef4e6b
Fix Python 3.11 compatibility problems.
...
Also needs https://github.com/tensorflow/tensorflow/pull/57085
2022-08-10 19:45:24 +00:00
jax authors
e81578a9fa
Merge pull request #11780 from ROCmSoftwarePlatform:rocm_update_dockerfile
...
PiperOrigin-RevId: 466756858
2022-08-10 12:19:13 -07:00
jax authors
8dba660b06
Merge pull request #11830 from mattjj:new-remat-landing-srs
...
PiperOrigin-RevId: 466736309
2022-08-10 11:03:44 -07:00
Matthew Johnson
be6f6bfe9f
set new jax.remat / jax.checkpoint to be on-by-default
2022-08-10 10:29:38 -07:00
jax authors
9922308342
Merge pull request #11818 from google:nightly-issue
...
PiperOrigin-RevId: 466717610
2022-08-10 10:08:35 -07:00
Peter Hawkins
11c3df4635
Disable NumPy dlpack tests on GPU.
...
NumPy dlpack support only works with CPU buffers.
PiperOrigin-RevId: 466709063
2022-08-10 09:35:41 -07:00
Peter Hawkins
03590d86c0
Disable lax_numpy test that seems to lead to nonterminating LLVM compilation.
...
PiperOrigin-RevId: 466682802
2022-08-10 07:50:51 -07:00
jax authors
8b2e4f975c
Merge pull request #11825 from mattjj:fix-type-annotation
...
PiperOrigin-RevId: 466550958
2022-08-09 20:21:10 -07:00
jax authors
582968cec5
Merge pull request #11824 from froystig:eltype-slice
...
PiperOrigin-RevId: 466550346
2022-08-09 20:15:13 -07:00
Matthew Johnson
d76754e40e
fix type annotation on remat
2022-08-09 19:57:40 -07:00
Roy Frostig
7955799ae3
defer to custom eltype for slice lowering rule
...
We already handled dynamic slice, but plain slice is eltype-polymorphic too.
2022-08-09 19:13:34 -07:00
Jake VanderPlas
7ec6acd981
nightly multiprocess test: create issue on failure
2022-08-09 19:12:32 -07:00
Yash Katariya
8a1b4785de
Use the same jaxlib
package name for nightlies. The __version__
will still contain the dev version (with datetime string in it).
...
PiperOrigin-RevId: 466534455
2022-08-09 18:53:36 -07:00
jax authors
169345311a
Merge pull request #11807 from pschuh:vmap_sharding_spec
...
PiperOrigin-RevId: 466472977
2022-08-09 14:42:27 -07:00
Sharad Vikram
3ec1b1b987
Check if XLA Executable has execute_with_token
before using it
...
PiperOrigin-RevId: 466470801
2022-08-09 14:34:57 -07:00
jax authors
6d9512aa39
Merge pull request #11808 from sharadmv:debugger-fix-flatten
...
PiperOrigin-RevId: 466448203
2022-08-09 13:17:44 -07:00
jax authors
22a4928a74
Merge pull request #11812 from mattjj:fix-ad-checkpoint-traceback-exclusion
...
PiperOrigin-RevId: 466444604
2022-08-09 13:07:56 -07:00
jax authors
bdc92482aa
Merge pull request #11811 from mattjj:remat-landing-fr-fr-fr-fr
...
PiperOrigin-RevId: 466443765
2022-08-09 13:02:28 -07:00
Matthew Johnson
580e7f39d5
fix traceback exclusion on new checkpoint
2022-08-09 12:45:21 -07:00
Parker Schuh
01df754630
Remove docs
2022-08-09 12:36:49 -07:00
Sharad Vikram
18f164ff1c
Try flattening the locals/globals dict in the debugger and have a
...
fallback if it fails
2022-08-09 12:31:33 -07:00
Matthew Johnson
666f4f838f
fix diffrax with new remat, do for cond what #11773 did for while_loop
2022-08-09 12:22:52 -07:00
Rohit Santhanam
1b3542427e
[ROCm] Update Dockerfile.rocm.
2022-08-09 11:09:10 -07:00
jax authors
88636d2b64
Merge pull request #11712 from jakevdp:delete-optimizers
...
PiperOrigin-RevId: 466398079
2022-08-09 10:25:28 -07:00
jax authors
d95b27ce1c
Merge pull request #11803 from hawkinsp:multigpu
...
PiperOrigin-RevId: 466374926
2022-08-09 09:02:21 -07:00
Jake VanderPlas
79406757d0
Remove deprecated jax.experimental.optimizers
...
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
a2c21958a5
Document multiprocess GPU support.
...
Fixes #2731
2022-08-09 11:31:05 -04:00
Yash Katariya
ce80a54805
Reshape sharding spec indices to the mesh shape to preserve the old semantics.
...
PiperOrigin-RevId: 466346873
2022-08-09 06:59:38 -07:00
jax authors
870e8a2928
Merge pull request #11806 from sharadmv:debugger-improvements
...
PiperOrigin-RevId: 466337260
2022-08-09 06:14:56 -07:00
Marc van Zee
856d91d251
[jax2tf] Fixes a bug in testing.
...
https://github.com/google/jax/pull/11575 introduced a bug: it catches any exception raised by self.ConvertAndCompare in primitives_test.py here, but does not reraise it if the if-clause is false. This is very dangerous, since any conversion that results in an error (for instance: if an op is missing for enable_xla=False) will now pass.
I had to disable some tests that were introduced somewhere else, most likely in XLA (X64 enabled). I have filed a separate bug for this (internally).
PiperOrigin-RevId: 466269561
2022-08-09 00:34:13 -07:00
Sharad Vikram
c34aa3933f
Various debugger improvements
...
- disables globals
- can opt out of filtering frames
- can limit number of frames
2022-08-08 19:54:29 -07:00
Parker Schuh
8fb957350c
Add spmd_axis_name
to vmap to allow constraining mapped PartitionSpecs.
2022-08-08 19:41:42 -07:00
Roy Frostig
4c18f1a580
link to both closed and open enhancement proposals
...
PiperOrigin-RevId: 466212251
2022-08-08 18:48:38 -07:00
jax authors
c4192dca65
Ensure PartitionSpec is picklable.
...
PartitionSpec instances currently don't roundtrip correctly through pickle, wrapping any value in an extra tuple, e.g.
```
> pickle.loads(pickle.dumps(PartitionSpec('model', 'batch')))
> PartitionSpec(('model', 'batch'),)
```
PiperOrigin-RevId: 466188998
2022-08-08 16:56:45 -07:00
jax authors
38ab3d88ae
Merge pull request #11799 from hawkinsp:jep
...
PiperOrigin-RevId: 466178951
2022-08-08 16:16:18 -07:00
Peter Hawkins
71b29b1cc6
Create JAX Enhancement Proposals (JEPs).
...
Migrate existing design documents to JEPs.
2022-08-08 16:13:58 -04:00
Parker Schuh
022fedde98
Move protobuf deps to be optional.
...
PiperOrigin-RevId: 466123822
2022-08-08 12:57:56 -07:00
Yash Katariya
53f16b7709
Calculate indices for pjit/xmap (basically Mesh
) cases via op_sharding_to_indices
function. Calculation for pmap indices uses the sharding_spec_indices
path that it currently takes.
...
PiperOrigin-RevId: 466120527
2022-08-08 12:45:27 -07:00
jax authors
8dce848f5f
Merge pull request #11779 from hawkinsp:npy
...
PiperOrigin-RevId: 466117679
2022-08-08 12:34:18 -07:00
jax authors
a2a84c40d5
Merge pull request #11797 from jakevdp:matrix-rank-mapped
...
PiperOrigin-RevId: 466114039
2022-08-08 12:20:41 -07:00
jax authors
ba59b7379b
Merge pull request #11796 from sudhakarsingh27:main
...
PiperOrigin-RevId: 466111110
2022-08-08 12:10:20 -07:00