mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 18:36:07 +00:00
136 lines
4.2 KiB
Python
136 lines
4.2 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import sys
|
|
|
|
def bool_env(varname: str, default: bool) -> bool:
|
|
"""Read an environment variable and interpret it as a boolean.
|
|
|
|
True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
|
|
false values are 'n', 'no', 'f', 'false', 'off', and '0'.
|
|
|
|
Args:
|
|
varname: the name of the variable
|
|
default: the default boolean value
|
|
Raises: ValueError if the environment variable is anything else.
|
|
"""
|
|
val = os.getenv(varname, str(default))
|
|
val = val.lower()
|
|
if val in ('y', 'yes', 't', 'true', 'on', '1'):
|
|
return True
|
|
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
|
|
return False
|
|
else:
|
|
raise ValueError("invalid truth value %r for environment %r" % (val, varname))
|
|
|
|
|
|
class Config(object):
|
|
def __init__(self):
|
|
self.values = {}
|
|
self.meta = {}
|
|
self.FLAGS = NameSpace(self.read)
|
|
self.use_absl = False
|
|
|
|
def update(self, name, val):
|
|
if self.use_absl:
|
|
setattr(self.absl_flags.FLAGS, name, val)
|
|
else:
|
|
self.check_exists(name)
|
|
if name not in self.values:
|
|
raise Exception("Unrecognized config option: {}".format(name))
|
|
self.values[name] = val
|
|
|
|
def read(self, name):
|
|
if self.use_absl:
|
|
return getattr(self.absl_flags.FLAGS, name)
|
|
else:
|
|
self.check_exists(name)
|
|
return self.values[name]
|
|
|
|
def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
|
|
if name in self.values:
|
|
raise Exception("Config option {} already defined".format(name))
|
|
self.values[name] = default
|
|
self.meta[name] = (opt_type, meta_args, meta_kwargs)
|
|
|
|
def check_exists(self, name):
|
|
if name not in self.values:
|
|
raise AttributeError("Unrecognized config option: {}".format(name))
|
|
|
|
def DEFINE_bool(self, name, default, *args, **kwargs):
|
|
self.add_option(name, default, bool, args, kwargs)
|
|
|
|
def DEFINE_integer(self, name, default, *args, **kwargs):
|
|
self.add_option(name, default, int, args, kwargs)
|
|
|
|
def DEFINE_string(self, name, default, *args, **kwargs):
|
|
self.add_option(name, default, str, args, kwargs)
|
|
|
|
def DEFINE_enum(self, name, default, *args, **kwargs):
|
|
self.add_option(name, default, 'enum', args, kwargs)
|
|
|
|
def config_with_absl(self):
|
|
# Run this before calling `app.run(main)` etc
|
|
import absl.flags as absl_FLAGS
|
|
from absl import app, flags as absl_flags
|
|
|
|
self.use_absl = True
|
|
self.absl_flags = absl_flags
|
|
absl_defs = { bool: absl_flags.DEFINE_bool,
|
|
int: absl_flags.DEFINE_integer,
|
|
str: absl_flags.DEFINE_string,
|
|
'enum': absl_flags.DEFINE_enum }
|
|
|
|
for name, val in self.values.items():
|
|
flag_type, meta_args, meta_kwargs = self.meta[name]
|
|
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
|
|
|
|
app.call_after_init(lambda: self.complete_absl_config(absl_flags))
|
|
|
|
def complete_absl_config(self, absl_flags):
|
|
for name, _ in self.values.items():
|
|
self.update(name, getattr(absl_flags.FLAGS, name))
|
|
|
|
def parse_flags_with_absl(self):
|
|
global already_configured_with_absl
|
|
if not already_configured_with_absl:
|
|
import absl.flags
|
|
self.config_with_absl()
|
|
absl.flags.FLAGS(sys.argv, known_only=True)
|
|
self.complete_absl_config(absl.flags)
|
|
already_configured_with_absl = True
|
|
|
|
|
|
class NameSpace(object):
|
|
def __init__(self, getter):
|
|
self._getter = getter
|
|
|
|
def __getattr__(self, name):
|
|
return self._getter(name)
|
|
|
|
|
|
config = Config()
|
|
flags = config
|
|
FLAGS = flags.FLAGS
|
|
|
|
already_configured_with_absl = False
|
|
|
|
flags.DEFINE_bool(
|
|
'jax_enable_checks',
|
|
bool_env('JAX_ENABLE_CHECKS', False),
|
|
help=
|
|
'Turn on invariant checking (core.skip_checks = False)'
|
|
)
|