13978 Commits

Author SHA1 Message Date
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
jax authors
ed9519dadf Merge pull request #13484 from google:index
PiperOrigin-RevId: 492323132
2022-12-01 16:02:50 -08:00
yashkatariya
a2870a182e Add jax.Array to the index page 2022-12-01 15:44:51 -08:00
jax authors
a6c7fa85bb Merge pull request #13483 from google:cross
PiperOrigin-RevId: 492314228
2022-12-01 15:23:19 -08:00
yashkatariya
70d50814b1 Add cross-linking for the migration guide and the parallelism with JAX
tutorial

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-12-01 14:42:59 -08:00
jax authors
8f92d6261c Merge pull request #13465 from jakevdp:x64-lax-numpy-test
PiperOrigin-RevId: 492299616
2022-12-01 14:24:51 -08:00
jax authors
60a45a400b Merge pull request #13480 from skye:doc_url
PiperOrigin-RevId: 492293304
2022-12-01 14:01:16 -08:00
Jake VanderPlas
3cf2924ed6 [x64] minor fixes for lax_numpy_test type safety 2022-12-01 13:56:42 -08:00
jax authors
b3b7eb68f1 Merge pull request #13464 from jakevdp:x64-host-callback
PiperOrigin-RevId: 492284287
2022-12-01 13:25:35 -08:00
Skye Wanderman-Milne
82b442fa52 [docs] Replace one more jax_Array.html reference
I missed this in #13479, thanks @yashk2810 for flagging!
2022-12-01 20:50:55 +00:00
jax authors
b2b4fcbeb4 Merge pull request #13479 from skye:doc_url
PiperOrigin-RevId: 492273702
2022-12-01 12:45:40 -08:00
jax authors
89c8df670d Merge pull request #13434 from jurahul:master
PiperOrigin-RevId: 492262252
2022-12-01 12:02:34 -08:00
Skye Wanderman-Milne
51db1cfd0e [docs] Rename "JAX in Parallelism" files so the URL matches the title. 2022-12-01 19:53:31 +00:00
Jake VanderPlas
904398a43d [x64] better type safety for host_callback 2022-12-01 11:47:07 -08:00
Yash Katariya
621322858d Fix vmap(jvp(pjit(f))) when pjit doesn't have any axis_resources
PiperOrigin-RevId: 492238366
2022-12-01 10:42:26 -08:00
jax authors
5847efc96b Merge pull request #13458 from jakevdp:f-strings
PiperOrigin-RevId: 492219655
2022-12-01 09:38:20 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
jax authors
79190b153a Merge pull request #13463 from jakevdp:x64-tests-types
PiperOrigin-RevId: 492211762
2022-12-01 09:12:20 -08:00
jax authors
f0f1d01570 Merge pull request #13470 from gnecula:tf_roll_poly
PiperOrigin-RevId: 492211433
2022-12-01 09:05:24 -08:00
George Necula
fcaf7f1169 [jax2tf] Fix the handling of jnp.roll for polymorphic shapes 2022-12-01 11:18:47 +01:00
George Necula
4ca05f428f [call_tf] Use the same platform for TF lowering as the embedding JAX computation
This requires some changes for abstract evaluation, when
JAX does not use a specific platform.

Also attempt to fix the case when the TF lowering fails because the TF computation
uses a tf.Variable on another device as that used for lowering.

PiperOrigin-RevId: 492112847
2022-11-30 23:22:24 -08:00
Yash Katariya
4443b861a5 Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Peter Hawkins
e835739eda Remove an unnecessary include/ from pybind11 include paths.
PiperOrigin-RevId: 492016679
2022-11-30 14:20:02 -08:00
jax authors
cfee99e477 Merge pull request #13435 from jakevdp:unused-code
PiperOrigin-RevId: 491985960
2022-11-30 12:18:50 -08:00
jax authors
7b86c2c610 Merge pull request #13430 from jakevdp:remove-dead
PiperOrigin-RevId: 491984563
2022-11-30 12:12:12 -08:00
Jake VanderPlas
94da50d8b3 Cleanup: remove dead code 2022-11-30 12:03:10 -08:00
Jake VanderPlas
0241567c3a remove dead code 2022-11-30 12:02:53 -08:00
jax authors
18f77a526b Add inferReturnTypes for PartitionIdOp.
Same as what already exists for ReplicaIdOp.

PiperOrigin-RevId: 491947476
2022-11-30 10:03:38 -08:00
Yuxin Wu
d5a058c7a8 doc improvement on initilaizer
PiperOrigin-RevId: 491947286
2022-11-30 10:03:23 -08:00
jax authors
07e681be8a Merge pull request #13441 from jakevdp:x64-dtypes-test
PiperOrigin-RevId: 491943702
2022-11-30 09:56:03 -08:00
jax authors
a1120d7910 Merge pull request #13456 from EltayebAhmed:patch-1
PiperOrigin-RevId: 491943659
2022-11-30 09:48:52 -08:00
jax authors
17ec90a5ea Merge pull request #13426 from jakevdp:sparse-tests
PiperOrigin-RevId: 491943629
2022-11-30 09:41:55 -08:00
jax authors
66ad07b5b7 Merge pull request #13442 from jakevdp:x64-api-test
PiperOrigin-RevId: 491933522
2022-11-30 09:04:46 -08:00
Eltayeb Ahmed
b5dc0638a2
Fix typo in docs/multi_process.md 2022-11-30 16:38:27 +00:00
jax authors
7f469afe8a Merge pull request #12877 from LenaMartens:check-error-types
PiperOrigin-RevId: 491915619
2022-11-30 07:52:31 -08:00
Peter Hawkins
6bda0d2863 Don't call dtypes.result_type() unnecessarily on the type of an array during abstractification.
Remove make_shaped_array since it has no more non-test users.

```
name        old cpu/op  new cpu/op  delta
device_put  69.4µs ± 6%  63.5µs ± 3%  -8.56%  (p=0.000 n=10+10)

name        old time/op             new time/op             delta
device_put  69.4µs ± 6%             63.5µs ± 3%  -8.56%        (p=0.000 n=10+10)
```

PiperOrigin-RevId: 491795793
2022-11-29 19:27:10 -08:00
jax authors
22f67d62dc Merge pull request #13440 from froystig:part-rng-parallel-doc
PiperOrigin-RevId: 491791000
2022-11-29 18:50:04 -08:00
jax authors
dba050a9a1 Merge pull request #13288 from imh:betaln_accuracy
PiperOrigin-RevId: 491778273
2022-11-29 17:26:24 -08:00
Yash Katariya
c4d91d203c Remove local_imports of sharding.py. Adding pxla local imports but then cleaning those up will be super easy since those will be the only ones left and restricted to sharding.py file only.
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support

PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
jax authors
84db79daa6 Merge pull request #13445 from jakevdp:37-todo
PiperOrigin-RevId: 491767813
2022-11-29 16:34:38 -08:00
Jake VanderPlas
e7f53479e2 Some cleanups related to dropping Python 3.7 2022-11-29 15:54:49 -08:00
jax authors
73680b20cc Merge pull request #13443 from jakevdp:drop-37
PiperOrigin-RevId: 491755299
2022-11-29 15:39:37 -08:00
Yash Katariya
ea63baf1ad Remove python 3.7 from testing and release builds
PiperOrigin-RevId: 491752393
2022-11-29 15:26:46 -08:00
Sholto Douglas
92fd8534cd Fixes b/259636412, all-gather failing when called within xmap in pjit. Piece by piece making xmap in pjit work with all collectives so that we can use it to write 'manual kernels' safely!
PiperOrigin-RevId: 491749906
2022-11-29 15:16:23 -08:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
Jake VanderPlas
2f235cd51a [sparse] refactor tests to fix compilation checks 2022-11-29 14:53:13 -08:00
Jake VanderPlas
e916a49d6c [x64] update api_test for type safety 2022-11-29 14:32:15 -08:00
Roy Frostig
b6fd3ff9d7 describe partitionable RNG mode in parallelism doc
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-11-29 14:31:06 -08:00
Jake VanderPlas
a6898c0393 [x64] update dtypes_test for type safety 2022-11-29 14:30:12 -08:00