math_benchmark: add --set_env flag

PiperOrigin-RevId: 515417422
This commit is contained in:
Emilio Cota 2023-03-09 12:58:58 -08:00 committed by jax authors
parent 36560538a3
commit 6f1d82916c

View File

@ -19,10 +19,22 @@ import google_benchmark as benchmark
import jax
import jax.numpy as jnp
import numpy as np
import os
import sys
from google_benchmark import Counter
from absl import app
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_multi_string(
"set_env", None,
"Specifies additional environment variables to be injected into the "
"environment (via --set_env=variable=value or --set_env=variable). "
"Using this flag is useful when running on remote machines where we do not "
"have direct control of the environment except for passing argument flags.")
def math_benchmark(*args):
def decorator(func):
for test_case in args[0]:
@ -127,6 +139,16 @@ def jax_binary_op(state, **kwargs):
state.iterations, Counter.kIsRate
)
def main(argv):
if FLAGS.set_env:
for env_str in FLAGS.set_env:
# Stop matching at the first '=' since we want to capture
# --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
env_list = env_str.split('=', 1)
if len(env_list) == 2:
os.environ[env_list[0]] = env_list[1];
benchmark.run_benchmarks()
if __name__ == '__main__':
benchmark.main()
sys.argv = benchmark.initialize(sys.argv)
app.run(main)