diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
deleted file mode 100644
index f066c6fc2..000000000
--- a/benchmarks/benchmark.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# Copyright 2020 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""A simple Python microbenchmarking library."""
-
-from collections import OrderedDict
-import csv
-import os
-import time
-from typing import Any, Optional, Union, Callable, List, Dict
-
-from absl import flags
-import numpy as np
-from tabulate import tabulate
-
-from jax._src.util import safe_zip
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string(
-    "export_dir", None,
-    "If set, will save results as CSV files in the specified directory.")
-flags.DEFINE_string(
-    "baseline_dir", None,
-    "If set, include comparison to baseline in results. Baselines should be "
-    "generated with --export_dir and benchmark names are matched to filenames.")
-
-def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
-              warmup: Optional[int] = None, name: Optional[str] = None,
-              target_total_secs: Optional[Union[int, float]] = None):
-  """Benchmarks ``f``. Prints the results and returns the raw times.
-
-  Args:
-    f: The function to be benchmarked. Should take no arguments.
-    iters: The number of iterations to run for. If none, runs until
-      ``target_total_secs`` has elapsed.
-    warmup: The number of warmup (untimed) iterations to run for.
-    name: The name of the benchmark. Defaults to f.__name__.
-    target_total_secs: If ``iters`` isn't specified, the minimum number of
-      seconds to run for. Defaults to the env var TARGET_TOTAL_SECS or 10 if
-      not set.
-
-  Returns:
-    An ndarray containing the number of seconds each iteration ran for.
-  """
-  if target_total_secs is None:
-    target_total_secs = int(os.getenv("TARGET_TOTAL_SECS", "10"))
-
-  if warmup is None:
-    if iters is None:
-      warmup = 1
-    else:
-      warmup = np.clip(1, iters // 10, 10)
-  for _ in range(warmup):
-    f()
-
-  times: List[float] = []
-  count = 0
-  while (count < iters if iters is not None
-         else sum(times) < target_total_secs):
-    start = time.time()
-    f()
-    end = time.time()
-    times.append(end - start)
-    count += 1
-
-  times_arr = np.array(times)
-  print("---------Benchmark results for %s---------" % (name or f.__name__))
-  print(f"mean={times_arr.mean()} std={times_arr.std()} "
-        f"%%std={_pstd(times_arr)} total={times_arr.sum()}")
-  print("#iters=%d #warmup=%d" % (count, warmup))
-  print()
-  return times_arr
-
-
-def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
-                    name: str, target_total_secs: Optional[int] = None):
-  """Benchmarks a function for several combinations of parameters.
-
-  Prints the summarized results in a table..
-
-  Args:
-    prepare: given kwargs returns a benchmark function specialized to the kwargs.
-    params_list: a list of kwargs on which to run the benchmark.
-    name: the name of this benchmark suite
-    target_total_secs: the ``target_total_secs`` to pass to ``benchmark``.
-  """
-  # Sort parameters alphabetically so benchmark results print consistently.
-  params_list = [OrderedDict(sorted(p.items())) for p in params_list]
-  assert all(p.keys() == params_list[0].keys() for p in params_list)
-
-  times = []
-  for params in params_list:
-    f = prepare(**params)
-    subname = name + "".join(f"_{n}={_param_str(p)}"
-                             for n, p in params.items())
-    times.append(benchmark(f, name=subname,
-                           target_total_secs=target_total_secs))
-
-  param_names = list(params_list[0].keys())
-  data_header = param_names + ["mean", "%std", "relative"]
-  data = [list(map(_param_str, params.values())) +
-          [t.mean(), _pstd(t), t.mean() / times[0].mean()]
-          for params, t in safe_zip(params_list, times)]
-
-  if FLAGS.baseline_dir:
-    mean_idx = len(param_names)
-    means = _get_baseline_means(FLAGS.baseline_dir, name)
-    assert len(means) == len(data), (means, data)
-    data_header.append("mean/baseline")
-    for idx, mean in enumerate(means):
-      data[idx].append(data[idx][mean_idx] / mean)
-
-  print("---------Benchmark summary for %s---------" % name)
-  print(tabulate(data, data_header))
-  print()
-
-  if FLAGS.export_dir:
-    filename = _export_results(data_header, data, FLAGS.export_dir, name)
-    print(f"Wrote {name} results to {filename}")
-    print()
-
-
-def _get_baseline_means(baseline_dir, name):
-  baseline_dir = os.path.expanduser(baseline_dir)
-  filename = os.path.join(baseline_dir, name + ".csv")
-  if not os.path.exists(filename):
-    raise FileNotFoundError("Can't find baseline file: %s" % filename)
-  with open(filename, newline="") as csvfile:
-    reader = csv.reader(csvfile)
-    header = next(reader)
-    mean_idx = header.index("mean")
-    return [float(row[mean_idx]) for row in reader]
-
-
-def _export_results(data_header, data, export_dir, name):
-  assert "mean" in data_header  # For future comparisons via _get_baseline_means
-  export_dir = os.path.expanduser(export_dir)
-  os.makedirs(export_dir, exist_ok=True)
-  filename = os.path.join(export_dir, name + ".csv")
-  with open(filename, "w", newline="") as csvfile:
-    writer = csv.writer(csvfile)
-    writer.writerow(data_header)
-    writer.writerows(data)
-  return filename
-
-
-def _param_str(param):
-  if callable(param):
-    return param.__name__
-  return str(param)
-
-
-def _pstd(x):
-  return x.std() / x.mean() * 100
diff --git a/benchmarks/pmap_benchmark.py b/benchmarks/pmap_benchmark.py
deleted file mode 100644
index 6ca555d3b..000000000
--- a/benchmarks/pmap_benchmark.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright 2020 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""To run on CPU with 500 CPU devices:
-
-CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 \
-python3 pmap_benchmark.py
-
-To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
-"""
-
-import math
-
-from absl import app
-
-import jax
-from jax import numpy as jnp
-from jax import pmap
-from jax.config import config
-
-from benchmarks import benchmark
-
-import numpy as np
-
-
-def pmap_shard_sharded_device_array_benchmark():
-  """Pmap benchmark focusing on shard_args fast path.
-
-  This is intended to measure how long it takes to dispatch a correctly-sharded
-  ShardedDeviceArray to pmap.
-  """
-
-  def get_benchmark_fn(nargs, nshards):
-    pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
-    shape = (nshards, 4)
-    args = [np.random.random(shape) for _ in range(nargs)]
-    sharded_args = pmap(lambda x: x)(args)
-    assert all(isinstance(arg, jax.Array) for arg in sharded_args)
-    def benchmark_fn():
-      for _ in range(100):
-        pmap_fn(*sharded_args)
-    return benchmark_fn
-
-  params = []
-  for nargs in (10, 100, 101, 500, 1000, 5000):
-    nshards = min(8, jax.local_device_count())
-    params.append({"nargs": nargs, "nshards": nshards})
-  for nshards in (2, 4, 8, 100, 500):
-    if nshards > jax.local_device_count(): continue
-    params.append({"nargs": 100, "nshards": nshards})
-  benchmark.benchmark_suite(get_benchmark_fn, params,
-                            "pmap_shard_sharded_device_array")
-
-
-def pmap_shard_device_array_benchmark():
-  """Pmap benchmark focusing on shard_args DeviceArray path.
-
-  This is intended to measure how long it takes to dispatch a DeviceArray to
-  pmap.
-  """
-
-  def get_benchmark_fn(nargs, nshards):
-    pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
-    shape = (nshards, 4)
-    args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
-    assert all(isinstance(arg, jax.Array) for arg in args)
-    def benchmark_fn():
-      for _ in range(10):
-        pmap_fn(*args)
-    return benchmark_fn
-
-  params = []
-  for nargs in (10, 100, 500):
-    nshards = min(8, jax.local_device_count())
-    params.append({"nargs": nargs, "nshards": nshards})
-  for nshards in (2, 4, 8):
-    if nshards > jax.local_device_count(): continue
-    params.append({"nargs": 100, "nshards": nshards})
-  benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_device_array")
-
-
-def pmap_shard_outputs_benchmark():
-  """Pmap benchmark focusing on array_result_handler path.
-
-  This is intended to measure how long it takes to construct ShardedDeviceArrays
-  from pmap.
-  """
-  def get_benchmark_fn(nouts, nshards):
-    pmap_fn = pmap(lambda x: [x + i for i in range(nouts)])
-    shape = (nshards, 4)
-    arg = np.random.random(shape)
-    def benchmark_fn():
-      for _ in range(100):
-        pmap_fn(arg)
-    return benchmark_fn
-
-  params = []
-  for nouts in (10, 100, 500, 1000, 5000):
-    nshards = min(8, jax.local_device_count())
-    params.append({"nouts": nouts, "nshards": nshards})
-  for nshards in (2, 4, 8, 100, 500):
-    if nshards > jax.local_device_count(): continue
-    params.append({"nouts": 100, "nshards": nshards})
-  benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
-
-
-def sharded_device_array_indexing_benchmark():
-  """Benchmark focusing on ShardedDeviceArray indexing."""
-  def get_benchmark_fn(indices_fn):
-    nshards = min(8, jax.local_device_count())
-    shape = (nshards, 8, 8)
-    def benchmark_fn():
-      arr = pmap(lambda x: x)(jnp.arange(math.prod(shape)).reshape(shape))
-      indices = indices_fn()
-      for idx in indices:
-        arr[idx]
-    return benchmark_fn
-
-  num_internal_iters = 1000
-
-  def integer_indices():
-    return (i for _ in range(num_internal_iters) for i in range(8))
-
-  def integer_2D_indices():
-    return ((i,i) for _ in range(num_internal_iters) for i in range(8))
-
-  params = []
-  params.append({"indices_fn": integer_indices})
-  params.append({"indices_fn": integer_2D_indices})
-  benchmark.benchmark_suite(get_benchmark_fn, params,
-                            "ShardedDeviceArray_indexing")
-
-
-def run_all_benchmarks():
-  pmap_shard_sharded_device_array_benchmark()
-  pmap_shard_device_array_benchmark()
-  pmap_shard_outputs_benchmark()
-  sharded_device_array_indexing_benchmark()
-
-
-def main(unused_argv):
-  run_all_benchmarks()
-
-
-if __name__ == "__main__":
-  config.config_with_absl()
-  app.run(main)