Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Yash Katariya
6a7a34603d
Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
...
This move is because sharded_jit is being deprecated.
PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Yash Katariya
e08bc27bf0
Speed up GDA initialization by making local_shards lazy and hiding checks behind config.jax_enable_checks
flag.
...
PiperOrigin-RevId: 438115859
2022-03-29 13:51:11 -07:00
Yash Katariya
a68b0f3a0a
Fix the mesh_axes of benchmarks to be Pspecs
...
PiperOrigin-RevId: 437932954
2022-03-28 21:51:01 -07:00
Qiao Zhang
5d7f639769
Add small and big matmul to api_benchmarks.
...
name cpu/op
jit_small_matmul 2.96µs ± 2%
jit_big_matmul 22.1µs ±21%
name time/op
jit_small_matmul 2.96µs ± 2%
jit_big_matmul 22.7µs ±21%
PiperOrigin-RevId: 435453853
2022-03-17 14:51:17 -07:00
Yash Katariya
4e47de66fc
Add the cache back now that Mesh's __hash__ is also being hashed on self.devices.shape
.
...
PiperOrigin-RevId: 425711067
2022-02-01 14:06:01 -08:00
Yash Katariya
f80887e69c
Couple of changes because of the serialization inconsistencies being observed.
...
* Remove the cache since one of the keys is global_mesh. Hash of global_mesh doesn't care of the mesh topology but just the devices. This is not good as indices are assigned to devices based on the topology. So if mesh is `(4, 2)` and then you give a new mesh `(2, 4)`, then the cache will return results for `(4, 2)` as the devices haven't changed. This is not right as the indices assigned in mesh `(2, 4)` will be different than `(4, 2)` as the topology is not the same.
```
mesh_devices = np.array(jax.devices()).reshape((4, 2))
mesh_axes = ('x' , 'y')
global_mesh1 = Mesh(mesh_devices, mesh_axes)
print(global_mesh1)
mesh_devices = np.array(jax.devices()).reshape((2, 4))
mesh_axes = ('x' , 'y')
global_mesh2 = Mesh(mesh_devices, mesh_axes)
print(global_mesh2)
hash(global_mesh1) == hash(global_mesh2)
Output:
Mesh(array([[0, 1],
[2, 3],
[4, 5],
[6, 7]]), ('x', 'y'))
Mesh(array([[0, 1, 2, 3],
[4, 5, 6, 7]]), ('x', 'y'))
True
```
* Remove the replica_id calculation math because after that the serialization inconsistencies were resolved. This is still confusing to me as to why its happening since I can't reproduce this in unit tests nor on small model runs. But I'll debug this in parallel. The important thing here is to unblock everyone. Replacing it with _hashed_index is still 2x faster than using _HashableIndex class.
PiperOrigin-RevId: 425525653
2022-01-31 20:48:17 -08:00
Yash Katariya
7f192c1946
Cache the expensive computations in GDA. For example get_shard_indices_replica_ids
can be the same for multiple variables in a neural network (global_shape, mesh_axes and global_mesh) can be the same
...
Note that the first time will be a little slow. The below timings you are seeing shows the caching working because the benchmark is running for multiple iterations and then the time is averaged over the number of iterations.
```
name time/op
gda_construction_callback_(4, 2)_['x', 'y'] 4.50ms ±10%
gda_construction_raw_(256, 8)_['x', 'y'] 5.82ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x', 'y'] 2.95ms ± 6%
indices_replica_id_calc_cached_(256, 8)_['x', 'y'] 28.7µs ± 1%
gda_construction_callback_(4, 2)_[None] 31.9ms ±20%
gda_construction_raw_(256, 8)_[None] 5.85ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_[None] 1.75ms ± 1%
indices_replica_id_calc_cached_(256, 8)_[None] 29.0µs ± 4%
gda_construction_callback_(4, 2)_['x'] 8.40ms ± 4%
gda_construction_raw_(256, 8)_['x'] 5.48ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x'] 1.89ms ± 1%
indices_replica_id_calc_cached_(256, 8)_['x'] 29.0µs ± 4%
gda_construction_callback_(4, 2)_['y'] 15.3ms ± 6%
gda_construction_raw_(256, 8)_['y'] 5.66ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_['y'] 1.82ms ± 2%
indices_replica_id_calc_cached_(256, 8)_['y'] 29.4µs ± 3%
gda_construction_callback_(4, 2)_[('x', 'y')] 4.29ms ± 5%
gda_construction_raw_(256, 8)_[('x', 'y')] 5.61ms ± 7%
indices_replica_id_calc__uncached_(256, 8)_[('x', 'y')] 3.81ms ±10%
indices_replica_id_calc_cached_(256, 8)_[('x', 'y')] 29.0µs ± 5%
gda_construction_raw_(128, 8)_['x', 'y'] 2.42ms ± 1%
indices_replica_id_calc__uncached_(128, 8)_['x', 'y'] 1.14ms ±11%
indices_replica_id_calc_cached_(128, 8)_['x', 'y'] 19.9µs ± 1%
gda_construction_raw_(4, 2)_['x', 'y'] 46.7µs ± 0%
indices_replica_id_calc__uncached_(4, 2)_['x', 'y'] 153µs ± 4%
indices_replica_id_calc_cached_(4, 2)_['x', 'y'] 11.1µs ± 8%
gda_construction_raw_(16, 4)_['x', 'y'] 164µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_['x', 'y'] 212µs ± 3%
indices_replica_id_calc_cached_(16, 4)_['x', 'y'] 11.3µs ± 1%
gda_construction_raw_(16, 4)_[('x', 'y')] 163µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_[('x', 'y')] 210µs ± 2%
indices_replica_id_calc_cached_(16, 4)_[('x', 'y')] 11.6µs ± 8%
```
PiperOrigin-RevId: 422639127
2022-01-18 13:53:52 -08:00
Yash Katariya
0532a63261
Optimizations for GDA to make creating GDA faster.
...
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function and not `_HashableIndex` which is a class which does not exist anymore) is 1.5 - 2 times slower than using math. markdaoust@ helped with the math here (going to office has its own perks :) )
* Get rid of `_HashableIndex` class and replace it with a function `_hashed_index`. Dataclass is extremely slow.
* Only calculate global_mesh.local_devices once. Even though its a cached property (but its after python 3.8)
```
name old time/op new time/op delta
gda_construction_callback_(4, 2)_['x', 'y'] 4.77ms ± 5% 4.74ms ± 5% ~ (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y'] 17.9ms ± 5% 9.0ms ± 2% -49.92% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y'] 11.4ms ± 2% 2.9ms ± 2% -74.52% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None] 34.0ms ±20% 30.5ms ± 2% ~ (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None] 15.9ms ± 2% 7.7ms ± 3% -51.56% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None] 9.39ms ± 3% 1.74ms ± 2% -81.44% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x'] 8.87ms ± 2% 8.92ms ± 3% ~ (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x'] 16.4ms ± 2% 7.7ms ± 1% -52.66% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x'] 9.85ms ± 1% 1.90ms ± 2% -80.68% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y'] 15.9ms ± 3% 16.0ms ± 5% ~ (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y'] 15.8ms ± 3% 7.6ms ± 1% -52.04% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y'] 9.29ms ± 1% 1.78ms ± 1% -80.79% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')] 4.65ms ± 2% 4.62ms ± 3% ~ (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')] 18.6ms ± 3% 9.7ms ± 5% -47.76% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')] 11.8ms ± 4% 3.5ms ± 2% -70.28% (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y'] 8.54ms ± 1% 4.03ms ± 2% -52.84% (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y'] 5.40ms ± 4% 1.10ms ± 1% -79.69% (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y'] 173µs ± 1% 193µs ± 3% +11.63% (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y'] 127µs ± 1% 147µs ± 1% +15.57% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00
Yash Katariya
fbb8b9f8c6
Benchmarks for GDA. Also move create_global_mesh
to test_utils since it was replicated in a lot of places.
...
PiperOrigin-RevId: 421142813
2022-01-11 15:43:05 -08:00
Tianjian Lu
6a0bfdd637
[JAX] Requires indices to be sorted and of int32 in _sparse_bcoo_matvec
test.
...
PiperOrigin-RevId: 417695937
2021-12-21 14:56:52 -08:00
Jake VanderPlas
b8296bb06c
benchmarks: add compilation of sparse operations
...
PiperOrigin-RevId: 405752775
2021-10-26 15:41:21 -07:00
Jake VanderPlas
c62452f2d2
benchmarks: add JIT versions of sparse.BCOO benchmarks
...
PiperOrigin-RevId: 405696495
2021-10-26 11:39:01 -07:00
Jake VanderPlas
5a3733788c
benchmarks: add sparse.BCOO todense/fromdense/matvec benchmarks
...
PiperOrigin-RevId: 405519824
2021-10-25 16:44:36 -07:00
Yash Katariya
bfbdfa87e7
Add a warmup loop to pmap_simple_8_devices_100_args benchmark so as to not measure the compile time.
...
PiperOrigin-RevId: 401402336
2021-10-06 19:51:35 -07:00
Peter Hawkins
256e7220ff
[JAX] Fix pylint errors.
...
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
6a1b626564
Remove jax.api.
...
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
2021-09-16 16:29:06 -04:00
Jean-Baptiste Lespiau
6cb8737c1a
Add a benchmark with many arguments.
...
PiperOrigin-RevId: 393216026
2021-08-26 15:06:09 -07:00
Jean-Baptiste Lespiau
afdd195e42
Internal only.
...
PiperOrigin-RevId: 388562206
2021-08-03 15:49:17 -07:00
Qiao Zhang
850bd66242
[JAX] Prune unused inputs in jit.
...
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args
PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00
Tom Hennigan
9d56552517
Remove special casing on npy_value when indexing sharded arrays.
...
Before:
```
GPU
sda_index_2 2912972 ns 2778716 ns 256
TPU
sda_index_1 769968 ns 751700 ns 921
sda_index_2 1510841 ns 1489716 ns 465
sda_index_8 6102259 ns 6027655 ns 117
```
After:
```
GPU
sda_index_2 28095 ns 27983 ns 25463
TPU
sda_index_1 10302 ns 10279 ns 67884
sda_index_2 20010 ns 19947 ns 34628
sda_index_8 78492 ns 78306 ns 8934
```
PiperOrigin-RevId: 368380864
2021-04-14 01:17:10 -07:00
Matthew Johnson
9802f3378e
add simple single-primitive eager benchmarks
2021-03-18 21:46:46 -07:00
Peter Hawkins
cdd36b1113
Improve API benchmarks.
...
Add benchmarks for different dispatch arg arities.
Add more blocking before and after benchmark loops that don't otherwise block.
2021-03-03 20:50:45 -05:00
Tom Hennigan
7b94c44af9
Sort imports.
...
PiperOrigin-RevId: 359532656
2021-02-25 08:52:09 -08:00
Peter Hawkins
160dfd343a
Revert import path changes to examples/ and benchmarks/
...
PiperOrigin-RevId: 352911869
2021-01-20 17:35:55 -08:00
Peter Hawkins
929a684a39
Small cleanups to dependency structure.
...
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
Peter Hawkins
3ac809ede3
[JAX] Move jax.util to jax._src_util.
...
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake Vanderplas
6a89f60683
fix benchmark sums ( #4329 )
2020-09-18 09:24:00 -07:00
Jean-Baptiste Lespiau
e95d5701e3
Add benchmarks for specifically the dispatch time. ( #4128 )
...
The goal is to distinguish the time it takes for `jitted_f` to return, and the time it takes to return and wait for the result.
We also add one to distinguish the time it takes to call the function with the argument transfer or without it.
e.g.
name time/op
jit_trivial_dispatch 28.9µs ± 2%
jit_trivial 31.5µs ± 5%
jit_simple_dispatch 60.7µs ± 4%
jit_simple 129µs ±24%
jit_simple_many_args_disptch 390µs ±19%
jit_simple_many_args 388µs ±16%
jit_dispatch_without_transfer 379µs ± 6%
jit_dispatch_with_transfer 450µs ± 5%
2020-08-27 17:02:13 +03:00
Jake Vanderplas
29aa9bfc8f
Cleanup: avoid jnp.prod & np.prod on array shapes ( #4086 )
2020-08-18 10:17:38 -07:00
Jake Vanderplas
a7c2cdea64
Cleanup: convert uses of import numpy as onp
in library code ( #3754 )
2020-07-14 13:05:31 -07:00
Chris Jones
c1aeb8b3fe
Add simple JAX API microbenchmarks. ( #3679 )
2020-07-09 10:02:23 -07:00
Jake Vanderplas
6aa8f2461c
Fix remaining flakes and use exclude within setup.cfg ( #3371 )
2020-06-08 22:58:03 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. ( #2973 )
...
* Replace np -> jnp, onp -> np in more places.
Context: #2370
* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Skye Wanderman-Milne
4b0334338e
Add pmap_shard_device_array_benchmark. ( #2864 )
...
Also renames pmap_shard_args_benchmark to pmap_shard_sharded_device_array_benchmark.
2020-04-27 17:21:05 -07:00
Skye Wanderman-Milne
8c2901cf4a
Add --export_dir
and --baseline_dir
flags to benchmark.py. ( #2677 )
...
`--export_dir` allows saving benchmark results to CSV files, and
`--baseline_dir` allows comparing results to a baseline exported via
`--export_dir`.
2020-04-13 10:07:05 -07:00
Skye Wanderman-Milne
3fe8bd027c
Adjust pmap_bechmark.py values to be more realistic. ( #2622 )
2020-04-06 16:38:34 -07:00
Skye Wanderman-Milne
c28c46e191
Add ShardedDeviceArray indexing benchmark. ( #2549 )
...
Example output:
```
---------Benchmark summary for ShardedDeviceArray_indexing---------
indices_fn mean %std relative
------------------ -------- ------- ----------
integer_indices 0.16901 8.52522 1
integer_2D_indices 18.4918 0 109.412
```
2020-03-31 15:52:41 -07:00
George Necula
fd52fbf411
Fix import in benchmarks
...
This works on my machine as 'python benchmarks/pmap_benchmark.py'. It also
follows the code in examples.
This will need a copybara rule to change the import to 'from jax.benchmarks import benchmark'
2020-03-31 11:48:08 +03:00
Skye Wanderman-Milne
24bbd2bc1d
Fix pmap_benchmark.py import ( #2524 )
2020-03-27 10:50:57 -07:00
George Necula
428377afb3
Added type annotations and removed unused imports ( #2472 )
...
* Added type annotations and removed unused imports
* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
George Necula
cd7ab0a9e0
Changed to pmap_benchmark to make it runnable in Google ( #2448 )
2020-03-19 06:56:59 +01:00
Skye Wanderman-Milne
75077a1441
Add pmap_benchmark.py ( #2409 )
...
Example output:
```
$ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py
2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
/usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
---------Benchmark results for pmap_shard_args_nargs=10_nshards=4---------
mean=0.034490 std=0.002890 %std=8.378140 total=2.000426
#iters=58 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=100_nshards=4---------
mean=0.091495 std=0.005935 %std=6.486871 total=2.012888
#iters=22 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=101_nshards=4---------
mean=0.113549 std=0.009080 %std=7.996712 total=2.043878
#iters=18 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=500_nshards=4---------
mean=0.356868 std=0.007960 %std=2.230518 total=2.141210
#iters=6 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=2---------
mean=0.022288 std=0.002946 %std=13.219607 total=2.005951
#iters=90 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=4---------
mean=0.035210 std=0.002024 %std=5.747389 total=2.006975
#iters=57 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=8---------
mean=0.048641 std=0.001578 %std=3.243398 total=2.042912
#iters=42 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=100---------
mean=0.257487 std=0.007190 %std=2.792452 total=2.059900
#iters=8 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=500---------
mean=1.696294 std=0.005097 %std=0.300473 total=3.392588
#iters=2 #warmup=1
---------Benchmark summary for pmap_shard_args---------
nargs nshards mean %std relative
------- --------- --------- --------- ----------
10 4 0.0344901 8.37814 1
100 4 0.0914949 6.48687 2.65279
101 4 0.113549 7.99671 3.29221
500 4 0.356868 2.23052 10.347
10 2 0.0222883 13.2196 0.646224
10 4 0.0352101 5.74739 1.02088
10 8 0.0486408 3.2434 1.41028
10 100 0.257487 2.79245 7.46555
10 500 1.69629 0.300473 49.182
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4---------
mean=0.061780 std=0.004737 %std=7.668032 total=2.038743
#iters=33 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4---------
mean=0.123264 std=0.005980 %std=4.851385 total=2.095494
#iters=17 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4---------
mean=0.471524 std=0.024051 %std=5.100792 total=2.357622
#iters=5 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2---------
mean=0.041546 std=0.004446 %std=10.700256 total=2.035745
#iters=49 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4---------
mean=0.063768 std=0.002756 %std=4.322039 total=2.040561
#iters=32 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8---------
mean=0.087285 std=0.005343 %std=6.121320 total=2.007556
#iters=23 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100---------
mean=0.623440 std=0.004038 %std=0.647725 total=2.493759
#iters=4 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500---------
mean=4.096676 std=0.000000 %std=0.000000 total=4.096676
#iters=1 #warmup=1
---------Benchmark summary for pmap_shard_outputs---------
nouts nshards mean %std relative
------- --------- --------- --------- ----------
10 4 0.0617801 7.66803 1
100 4 0.123264 4.85139 1.99521
500 4 0.471524 5.10079 7.6323
10 2 0.0415458 10.7003 0.672479
10 4 0.0637675 4.32204 1.03217
10 8 0.087285 6.12132 1.41283
10 100 0.62344 0.647725 10.0913
10 500 4.09668 0 66.3106
```
2020-03-17 14:31:25 -07:00