Yash Katariya
934bc4e1b3
Move PartitionSpec
and Mesh
out of experimental and into the sharding
namespace. The new API endpoint is jax.sharding.PartitionSpec
and jax.sharding.Mesh
.
...
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
jax authors
ed9519dadf
Merge pull request #13484 from google:index
...
PiperOrigin-RevId: 492323132
2022-12-01 16:02:50 -08:00
yashkatariya
a2870a182e
Add jax.Array to the index page
2022-12-01 15:44:51 -08:00
jax authors
a6c7fa85bb
Merge pull request #13483 from google:cross
...
PiperOrigin-RevId: 492314228
2022-12-01 15:23:19 -08:00
yashkatariya
70d50814b1
Add cross-linking for the migration guide and the parallelism with JAX
...
tutorial
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-12-01 14:42:59 -08:00
jax authors
8f92d6261c
Merge pull request #13465 from jakevdp:x64-lax-numpy-test
...
PiperOrigin-RevId: 492299616
2022-12-01 14:24:51 -08:00
jax authors
60a45a400b
Merge pull request #13480 from skye:doc_url
...
PiperOrigin-RevId: 492293304
2022-12-01 14:01:16 -08:00
Jake VanderPlas
3cf2924ed6
[x64] minor fixes for lax_numpy_test type safety
2022-12-01 13:56:42 -08:00
jax authors
b3b7eb68f1
Merge pull request #13464 from jakevdp:x64-host-callback
...
PiperOrigin-RevId: 492284287
2022-12-01 13:25:35 -08:00
Skye Wanderman-Milne
82b442fa52
[docs] Replace one more jax_Array.html reference
...
I missed this in #13479 , thanks @yashk2810 for flagging!
2022-12-01 20:50:55 +00:00
jax authors
b2b4fcbeb4
Merge pull request #13479 from skye:doc_url
...
PiperOrigin-RevId: 492273702
2022-12-01 12:45:40 -08:00
jax authors
89c8df670d
Merge pull request #13434 from jurahul:master
...
PiperOrigin-RevId: 492262252
2022-12-01 12:02:34 -08:00
Skye Wanderman-Milne
51db1cfd0e
[docs] Rename "JAX in Parallelism" files so the URL matches the title.
2022-12-01 19:53:31 +00:00
Jake VanderPlas
904398a43d
[x64] better type safety for host_callback
2022-12-01 11:47:07 -08:00
Yash Katariya
621322858d
Fix vmap(jvp(pjit(f)))
when pjit doesn't have any axis_resources
...
PiperOrigin-RevId: 492238366
2022-12-01 10:42:26 -08:00
jax authors
5847efc96b
Merge pull request #13458 from jakevdp:f-strings
...
PiperOrigin-RevId: 492219655
2022-12-01 09:38:20 -08:00
Jake VanderPlas
26d9837b36
Switch to new-style f-strings
2022-12-01 09:14:16 -08:00
jax authors
79190b153a
Merge pull request #13463 from jakevdp:x64-tests-types
...
PiperOrigin-RevId: 492211762
2022-12-01 09:12:20 -08:00
jax authors
f0f1d01570
Merge pull request #13470 from gnecula:tf_roll_poly
...
PiperOrigin-RevId: 492211433
2022-12-01 09:05:24 -08:00
George Necula
fcaf7f1169
[jax2tf] Fix the handling of jnp.roll for polymorphic shapes
2022-12-01 11:18:47 +01:00
George Necula
4ca05f428f
[call_tf] Use the same platform for TF lowering as the embedding JAX computation
...
This requires some changes for abstract evaluation, when
JAX does not use a specific platform.
Also attempt to fix the case when the TF lowering fails because the TF computation
uses a tf.Variable on another device as that used for lowering.
PiperOrigin-RevId: 492112847
2022-11-30 23:22:24 -08:00
Yash Katariya
4443b861a5
Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
...
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Jake VanderPlas
f09fd8a4e9
[x64] minor test-only updates for better type safety
2022-11-30 15:18:40 -08:00
Peter Hawkins
e835739eda
Remove an unnecessary include/ from pybind11 include paths.
...
PiperOrigin-RevId: 492016679
2022-11-30 14:20:02 -08:00
jax authors
cfee99e477
Merge pull request #13435 from jakevdp:unused-code
...
PiperOrigin-RevId: 491985960
2022-11-30 12:18:50 -08:00
jax authors
7b86c2c610
Merge pull request #13430 from jakevdp:remove-dead
...
PiperOrigin-RevId: 491984563
2022-11-30 12:12:12 -08:00
Jake VanderPlas
94da50d8b3
Cleanup: remove dead code
2022-11-30 12:03:10 -08:00
Jake VanderPlas
0241567c3a
remove dead code
2022-11-30 12:02:53 -08:00
jax authors
18f77a526b
Add inferReturnTypes
for PartitionIdOp.
...
Same as what already exists for ReplicaIdOp.
PiperOrigin-RevId: 491947476
2022-11-30 10:03:38 -08:00
Yuxin Wu
d5a058c7a8
doc improvement on initilaizer
...
PiperOrigin-RevId: 491947286
2022-11-30 10:03:23 -08:00
jax authors
07e681be8a
Merge pull request #13441 from jakevdp:x64-dtypes-test
...
PiperOrigin-RevId: 491943702
2022-11-30 09:56:03 -08:00
jax authors
a1120d7910
Merge pull request #13456 from EltayebAhmed:patch-1
...
PiperOrigin-RevId: 491943659
2022-11-30 09:48:52 -08:00
jax authors
17ec90a5ea
Merge pull request #13426 from jakevdp:sparse-tests
...
PiperOrigin-RevId: 491943629
2022-11-30 09:41:55 -08:00
jax authors
66ad07b5b7
Merge pull request #13442 from jakevdp:x64-api-test
...
PiperOrigin-RevId: 491933522
2022-11-30 09:04:46 -08:00
Eltayeb Ahmed
b5dc0638a2
Fix typo in docs/multi_process.md
2022-11-30 16:38:27 +00:00
jax authors
7f469afe8a
Merge pull request #12877 from LenaMartens:check-error-types
...
PiperOrigin-RevId: 491915619
2022-11-30 07:52:31 -08:00
Peter Hawkins
6bda0d2863
Don't call dtypes.result_type()
unnecessarily on the type of an array during abstractification.
...
Remove make_shaped_array since it has no more non-test users.
```
name old cpu/op new cpu/op delta
device_put 69.4µs ± 6% 63.5µs ± 3% -8.56% (p=0.000 n=10+10)
name old time/op new time/op delta
device_put 69.4µs ± 6% 63.5µs ± 3% -8.56% (p=0.000 n=10+10)
```
PiperOrigin-RevId: 491795793
2022-11-29 19:27:10 -08:00
jax authors
22f67d62dc
Merge pull request #13440 from froystig:part-rng-parallel-doc
...
PiperOrigin-RevId: 491791000
2022-11-29 18:50:04 -08:00
jax authors
dba050a9a1
Merge pull request #13288 from imh:betaln_accuracy
...
PiperOrigin-RevId: 491778273
2022-11-29 17:26:24 -08:00
Yash Katariya
c4d91d203c
Remove local_imports of sharding.py. Adding pxla
local imports but then cleaning those up will be super easy since those will be the only ones left and restricted to sharding.py
file only.
...
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support
PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
jax authors
84db79daa6
Merge pull request #13445 from jakevdp:37-todo
...
PiperOrigin-RevId: 491767813
2022-11-29 16:34:38 -08:00
Jake VanderPlas
e7f53479e2
Some cleanups related to dropping Python 3.7
2022-11-29 15:54:49 -08:00
jax authors
73680b20cc
Merge pull request #13443 from jakevdp:drop-37
...
PiperOrigin-RevId: 491755299
2022-11-29 15:39:37 -08:00
Yash Katariya
ea63baf1ad
Remove python 3.7 from testing and release builds
...
PiperOrigin-RevId: 491752393
2022-11-29 15:26:46 -08:00
Sholto Douglas
92fd8534cd
Fixes b/259636412, all-gather failing when called within xmap in pjit. Piece by piece making xmap in pjit work with all collectives so that we can use it to write 'manual kernels' safely!
...
PiperOrigin-RevId: 491749906
2022-11-29 15:16:23 -08:00
Jake VanderPlas
cb62a31653
Drop support for Python 3.7
2022-11-29 15:01:47 -08:00
Jake VanderPlas
2f235cd51a
[sparse] refactor tests to fix compilation checks
2022-11-29 14:53:13 -08:00
Jake VanderPlas
e916a49d6c
[x64] update api_test for type safety
2022-11-29 14:32:15 -08:00
Roy Frostig
b6fd3ff9d7
describe partitionable RNG mode in parallelism doc
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-11-29 14:31:06 -08:00
Jake VanderPlas
a6898c0393
[x64] update dtypes_test for type safety
2022-11-29 14:30:12 -08:00