mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add --export_dir
and --baseline_dir
flags to benchmark.py. (#2677)
`--export_dir` allows saving benchmark results to CSV files, and `--baseline_dir` allows comparing results to a baseline exported via `--export_dir`.
This commit is contained in:
parent
afefc927b6
commit
8c2901cf4a
@ -14,16 +14,26 @@
|
||||
"""A simple Python microbenchmarking library."""
|
||||
|
||||
from collections import OrderedDict
|
||||
import csv
|
||||
from numbers import Number
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional, Union, Callable, List, Dict
|
||||
|
||||
from absl import flags
|
||||
import numpy as onp
|
||||
from tabulate import tabulate
|
||||
|
||||
from jax.util import safe_zip
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string(
|
||||
"export_dir", None,
|
||||
"If set, will save results as CSV files in the specified directory.")
|
||||
flags.DEFINE_string(
|
||||
"baseline_dir", None,
|
||||
"If set, include comparison to baseline in results. Baselines should be "
|
||||
"generated with --export_dir and benchmark names are matched to filenames.")
|
||||
|
||||
def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
|
||||
warmup: Optional[int] = None, name: Optional[str] = None,
|
||||
@ -97,14 +107,53 @@ def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
|
||||
times.append(benchmark(f, name=subname,
|
||||
target_total_secs=target_total_secs))
|
||||
|
||||
print("---------Benchmark summary for %s---------" % name)
|
||||
param_names = list(params_list[0].keys())
|
||||
print(tabulate([tuple(map(_param_str, params.values())) +
|
||||
(t.mean(), _pstd(t), t.mean() / times[0].mean())
|
||||
for params, t in safe_zip(params_list, times)],
|
||||
param_names + ["mean", "%std", "relative"]))
|
||||
data_header = param_names + ["mean", "%std", "relative"]
|
||||
data = [list(map(_param_str, params.values())) +
|
||||
[t.mean(), _pstd(t), t.mean() / times[0].mean()]
|
||||
for params, t in safe_zip(params_list, times)]
|
||||
|
||||
if FLAGS.baseline_dir:
|
||||
mean_idx = len(param_names)
|
||||
means = _get_baseline_means(FLAGS.baseline_dir, name)
|
||||
assert len(means) == len(data), (means, data)
|
||||
data_header.append("mean/baseline")
|
||||
for idx, mean in enumerate(means):
|
||||
data[idx].append(data[idx][mean_idx] / mean)
|
||||
|
||||
print("---------Benchmark summary for %s---------" % name)
|
||||
print(tabulate(data, data_header))
|
||||
print()
|
||||
|
||||
if FLAGS.export_dir:
|
||||
filename = _export_results(data_header, data, FLAGS.export_dir, name)
|
||||
print("Wrote %s results to %s" % (name, filename))
|
||||
print()
|
||||
|
||||
|
||||
def _get_baseline_means(baseline_dir, name):
|
||||
baseline_dir = os.path.expanduser(baseline_dir)
|
||||
filename = os.path.join(baseline_dir, name + ".csv")
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError("Can't find baseline file: %s" % filename)
|
||||
with open(filename, newline="") as csvfile:
|
||||
reader = csv.reader(csvfile)
|
||||
header = next(reader)
|
||||
mean_idx = header.index("mean")
|
||||
return [float(row[mean_idx]) for row in reader]
|
||||
|
||||
|
||||
def _export_results(data_header, data, export_dir, name):
|
||||
assert "mean" in data_header # For future comparisons via _get_baseline_means
|
||||
export_dir = os.path.expanduser(export_dir)
|
||||
os.makedirs(export_dir, exist_ok=True)
|
||||
filename = os.path.join(export_dir, name + ".csv")
|
||||
with open(filename, "w", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(data_header)
|
||||
writer.writerows(data)
|
||||
return filename
|
||||
|
||||
|
||||
def _param_str(param):
|
||||
if callable(param):
|
||||
|
Loading…
x
Reference in New Issue
Block a user