rocm_jax/benchmarks/pmap_benchmark.py

157 lines
4.9 KiB
Python
Raw Normal View History

# Copyright 2020 The JAX Authors.
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To run on CPU with 500 CPU devices:
    CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 \
      python3 pmap_benchmark.py
To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
"""
from absl import app
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
import jax
from jax import numpy as jnp
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
from jax import pmap
from jax.config import config
from jax._src.util import prod
from benchmarks import benchmark
2020-03-27 10:50:57 -07:00
import numpy as np
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
def pmap_shard_sharded_device_array_benchmark():
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
"""Pmap benchmark focusing on shard_args fast path.
This is intended to measure how long it takes to dispatch a correctly-sharded
ShardedDeviceArray to pmap.
"""
def get_benchmark_fn(nargs, nshards):
  """Build a callable that repeatedly dispatches pre-sharded arrays to pmap.

  All setup (creating the pmapped function and sharding the inputs) happens
  here, outside the returned callable, so the callable itself only performs
  the pmap calls.

  Args:
    nargs: number of array arguments passed to the pmapped function.
    nshards: size of the leading axis of every argument; each argument has
      shape `(nshards, 4)`.

  Returns:
    A zero-argument function that calls the pmapped sum 100 times on the
    already-sharded arguments and returns None.
  """
  # Function under test: stacks all per-shard arguments and sums them.
  pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))

  shape = (nshards, 4)
  args = [np.random.random(shape) for _ in range(nargs)]

  # Run the host arrays through an identity pmap once so that the benchmark
  # inputs are the outputs of a pmap rather than plain NumPy arrays.
  sharded_args = pmap(lambda x: x)(args)
  # Sanity check that the inputs really are ShardedDeviceArrays; otherwise
  # the benchmark would not be exercising the intended path.
  assert all(isinstance(arg, jax.pxla.ShardedDeviceArray)
             for arg in sharded_args)

  def benchmark_fn():
    # Many calls per invocation so a single timing sample covers 100
    # dispatches.
    for _ in range(100):
      pmap_fn(*sharded_args)

  return benchmark_fn
params = []
for nargs in (10, 100, 101, 500, 1000, 5000):
nshards = min(8, jax.local_device_count())
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
params.append({"nargs": nargs, "nshards": nshards})
for nshards in (2, 4, 8, 100, 500):
if nshards > jax.local_device_count(): continue
params.append({"nargs": 100, "nshards": nshards})
benchmark.benchmark_suite(get_benchmark_fn, params,
"pmap_shard_sharded_device_array")
def pmap_shard_device_array_benchmark():
  """Pmap benchmark focusing on the shard_args DeviceArray path.

  Measures how long it takes to dispatch DeviceArray arguments to pmap. The
  suite first varies the number of arguments at a fixed shard count (at most
  8), then varies the shard count with 100 arguments, leaving out shard
  counts that exceed the number of local devices.
  """

  def get_benchmark_fn(nargs, nshards):
    # The pmapped function and its inputs are built once, outside the timed
    # callable, so only the repeated calls are measured.
    sum_all = pmap(lambda *xs: jnp.sum(jnp.array(xs)))
    device_args = []
    for _ in range(nargs):
      device_args.append(jnp.array(np.random.random((nshards, 4))))
    assert all(isinstance(a, jax.xla.DeviceArray) for a in device_args)

    def benchmark_fn():
      # 10 dispatches per timed call.
      for _ in range(10):
        sum_all(*device_args)

    return benchmark_fn

  # Sweep over the argument count with the shard count capped at 8.
  params = [
      {"nargs": nargs, "nshards": min(8, jax.local_device_count())}
      for nargs in (10, 100, 500)
  ]
  # Sweep over the shard count, keeping only counts this host can satisfy.
  params.extend(
      {"nargs": 100, "nshards": nshards}
      for nshards in (2, 4, 8)
      if nshards <= jax.local_device_count()
  )
  benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_device_array")
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
def pmap_shard_outputs_benchmark():
  """Pmap benchmark focusing on the array_result_handler path.

  Measures how long it takes to construct ShardedDeviceArrays from pmap. The
  suite first varies the number of outputs at a fixed shard count (at most
  8), then varies the shard count with 100 outputs, leaving out shard counts
  that exceed the number of local devices.
  """

  def get_benchmark_fn(nouts, nshards):
    # One input array, `nouts` outputs per call.
    multi_output_fn = pmap(lambda x: [x + offset for offset in range(nouts)])
    host_input = np.random.random((nshards, 4))

    def benchmark_fn():
      # 100 calls per timed invocation.
      for _ in range(100):
        multi_output_fn(host_input)

    return benchmark_fn

  # Sweep over the output count with the shard count capped at 8.
  params = [
      {"nouts": nouts, "nshards": min(8, jax.local_device_count())}
      for nouts in (10, 100, 500, 1000, 5000)
  ]
  # Sweep over the shard count, keeping only counts this host can satisfy.
  params.extend(
      {"nouts": 100, "nshards": nshards}
      for nshards in (2, 4, 8, 100, 500)
      if nshards <= jax.local_device_count()
  )
  benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
def sharded_device_array_indexing_benchmark():
  """Benchmark focusing on ShardedDeviceArray indexing."""
  num_internal_iters = 1000

  # NOTE(review): both index producers always walk 0..7, while the array's
  # leading axis is min(8, local_device_count()) -- confirm this is intended
  # on hosts with fewer than 8 devices.
  def integer_indices():
    # Yields 0..7, repeated num_internal_iters times.
    for _ in range(num_internal_iters):
      for i in range(8):
        yield i

  def integer_2D_indices():
    # Yields (0, 0)..(7, 7), repeated num_internal_iters times.
    for _ in range(num_internal_iters):
      for i in range(8):
        yield (i, i)

  def get_benchmark_fn(indices_fn):
    nshards = min(8, jax.local_device_count())
    shape = (nshards, 8, 8)

    def benchmark_fn():
      # The array is rebuilt inside the timed callable on every invocation.
      identity = pmap(lambda x: x)
      arr = identity(jnp.arange(prod(shape)).reshape(shape))
      for idx in indices_fn():
        arr[idx]

    return benchmark_fn

  params = [
      {"indices_fn": integer_indices},
      {"indices_fn": integer_2D_indices},
  ]
  benchmark.benchmark_suite(get_benchmark_fn, params,
                            "ShardedDeviceArray_indexing")
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
def run_all_benchmarks():
  """Run the four benchmark suites listed below, in order."""
  suites = (
      pmap_shard_sharded_device_array_benchmark,
      pmap_shard_device_array_benchmark,
      pmap_shard_outputs_benchmark,
      sharded_device_array_indexing_benchmark,
  )
  for run_suite in suites:
    run_suite()
Add pmap_benchmark.py (#2409) Example output: ``` $ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py 2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected /usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.') ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=100_nshards=4--------- mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=101_nshards=4--------- mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=500_nshards=4--------- mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=2--------- mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=4--------- mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=8--------- mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=100--------- mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1 ---------Benchmark results for pmap_shard_args_nargs=10_nshards=500--------- mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1 ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative ------- 
--------- --------- --------- ---------- 10 4 0.0344901 8.37814 1 100 4 0.0914949 6.48687 2.65279 101 4 0.113549 7.99671 3.29221 500 4 0.356868 2.23052 10.347 10 2 0.0222883 13.2196 0.646224 10 4 0.0352101 5.74739 1.02088 10 8 0.0486408 3.2434 1.41028 10 100 0.257487 2.79245 7.46555 10 500 1.69629 0.300473 49.182 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4--------- mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4--------- mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2--------- mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4--------- mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8--------- mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100--------- mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1 ---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500--------- mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1 ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative ------- --------- --------- --------- ---------- 10 4 0.0617801 7.66803 1 100 4 0.123264 4.85139 1.99521 500 4 0.471524 5.10079 7.6323 10 2 0.0415458 10.7003 0.672479 10 4 0.0637675 4.32204 1.03217 10 8 0.087285 6.12132 1.41283 10 100 0.62344 0.647725 10.0913 10 500 4.09668 0 66.3106 ```
2020-03-17 14:31:25 -07:00
def main(unused_argv):
  """Script entry point: run every benchmark via `run_all_benchmarks()`.

  Args:
    unused_argv: Positional argument supplied by the caller (this function is
      handed to `app.run` in the `__main__` guard). It is never read here.
  """
  run_all_benchmarks()
if __name__ == "__main__":
  # Only run the benchmarks when executed as a script, not when imported.
  # NOTE(review): config_with_absl() presumably hooks the config options up to
  # absl flags, which is why it is called before app.run() -- confirm.
  config.config_with_absl()
  # app.run invokes main() and passes it the argument main names unused_argv.
  app.run(main)