mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
math_benchmark: add --set_env flag
PiperOrigin-RevId: 515417422
This commit is contained in:
parent
36560538a3
commit
6f1d82916c
@ -19,10 +19,22 @@ import google_benchmark as benchmark
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
|
||||
from google_benchmark import Counter
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_multi_string(
|
||||
"set_env", None,
|
||||
"Specifies additional environment variables to be injected into the "
|
||||
"environment (via --set_env=variable=value or --set_env=variable). "
|
||||
"Using this flag is useful when running on remote machines where we do not "
|
||||
"have direct control of the environment except for passing argument flags.")
|
||||
|
||||
def math_benchmark(*args):
|
||||
def decorator(func):
|
||||
for test_case in args[0]:
|
||||
@ -127,6 +139,16 @@ def jax_binary_op(state, **kwargs):
|
||||
state.iterations, Counter.kIsRate
|
||||
)
|
||||
|
||||
def main(argv):
|
||||
if FLAGS.set_env:
|
||||
for env_str in FLAGS.set_env:
|
||||
# Stop matching at the first '=' since we want to capture
|
||||
# --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
|
||||
env_list = env_str.split('=', 1)
|
||||
if len(env_list) == 2:
|
||||
os.environ[env_list[0]] = env_list[1];
|
||||
benchmark.run_benchmarks()
|
||||
|
||||
if __name__ == '__main__':
|
||||
benchmark.main()
|
||||
sys.argv = benchmark.initialize(sys.argv)
|
||||
app.run(main)
|
||||
|
Loading…
x
Reference in New Issue
Block a user