diff --git a/benchmarks/math_benchmark.py b/benchmarks/math_benchmark.py index 4a0a0b7a6..74674c159 100644 --- a/benchmarks/math_benchmark.py +++ b/benchmarks/math_benchmark.py @@ -19,10 +19,22 @@ import google_benchmark as benchmark import jax import jax.numpy as jnp import numpy as np +import os +import sys from google_benchmark import Counter +from absl import app +from absl import flags +FLAGS = flags.FLAGS +flags.DEFINE_multi_string( + "set_env", None, + "Specifies additional environment variables to be injected into the " + "environment (via --set_env=variable=value or --set_env=variable). " + "Using this flag is useful when running on remote machines where we do not " + "have direct control of the environment except for passing argument flags.") + def math_benchmark(*args): def decorator(func): for test_case in args[0]: @@ -127,6 +139,16 @@ def jax_binary_op(state, **kwargs): state.iterations, Counter.kIsRate ) +def main(argv): + if FLAGS.set_env: + for env_str in FLAGS.set_env: + # Stop matching at the first '=' since we want to capture + # --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO. + env_list = env_str.split('=', 1) + if len(env_list) == 2: + os.environ[env_list[0]] = env_list[1]; + benchmark.run_benchmarks() if __name__ == '__main__': - benchmark.main() + sys.argv = benchmark.initialize(sys.argv) + app.run(main)