rocm_jax/jax/config.py
Matthew Johnson 4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 for more information.
2020-07-30 12:59:36 -07:00

155 lines
4.7 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
def bool_env(varname: str, default: bool) -> bool:
    """Interpret an environment variable as a boolean flag.

    Matching is case-insensitive. The accepted true spellings are 'y', 'yes',
    't', 'true', 'on' and '1'; the accepted false spellings are 'n', 'no',
    'f', 'false', 'off' and '0'.

    Args:
      varname: name of the environment variable to look up.
      default: boolean whose string form is used when the variable is unset.

    Returns:
      True or False, according to which spelling table the text matches.

    Raises:
      ValueError: if the text matches neither table.
    """
    truthy = ('y', 'yes', 't', 'true', 'on', '1')
    falsy = ('n', 'no', 'f', 'false', 'off', '0')
    # An unset variable falls back to str(default) ('True' / 'False' for a
    # bool), which the tables below recognise once lower-cased.
    text = os.getenv(varname, str(default)).lower()
    if text in truthy:
        return True
    if text in falsy:
        return False
    raise ValueError("invalid truth value %r for environment %r" % (text, varname))
class Config:
    """Registry of named configuration options, optionally backed by absl flags.

    Options are declared with the ``DEFINE_*`` methods, which record a default
    value in ``self.values`` and the declaration metadata in ``self.meta``.
    Until ``config_with_absl`` is called, reads and updates go to
    ``self.values``; afterwards (``self.use_absl`` is True) they are forwarded
    to ``absl.flags.FLAGS``.
    """

    def __init__(self):
        # name -> current value (only consulted while absl is not in use).
        self.values = {}
        # name -> (opt_type, meta_args, meta_kwargs) as passed to add_option.
        self.meta = {}
        # Attribute-style read access: FLAGS.some_name calls self.read(name).
        self.FLAGS = NameSpace(self.read)
        self.use_absl = False
        self.omnistaging_enabled = False
        # Zero-argument callables run once by enable_omnistaging().
        self.omnistaging_enablers = []

    def update(self, name, val):
        """Set option ``name`` to ``val``.

        Raises:
          AttributeError: if absl is not in use and ``name`` was never defined.
        """
        if self.use_absl:
            setattr(self.absl_flags.FLAGS, name, val)
        else:
            # check_exists raises for unknown names, so no further membership
            # test is needed before the assignment.
            self.check_exists(name)
            self.values[name] = val

    def read(self, name):
        """Return the current value of option ``name``.

        Raises:
          AttributeError: if absl is not in use and ``name`` was never defined.
        """
        if self.use_absl:
            return getattr(self.absl_flags.FLAGS, name)
        self.check_exists(name)
        return self.values[name]

    def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
        """Register a new option with its default value and metadata.

        Args:
          name: option name; must not already be registered.
          default: initial value stored for the option.
          opt_type: key later used to pick the matching absl ``DEFINE_*``
            function (``bool``, ``int``, ``str`` or the string ``'enum'``).
          meta_args: extra positional arguments forwarded to absl.
          meta_kwargs: extra keyword arguments forwarded to absl.

        Raises:
          Exception: if ``name`` is already defined.
        """
        if name in self.values:
            raise Exception("Config option {} already defined".format(name))
        self.values[name] = default
        self.meta[name] = (opt_type, meta_args, meta_kwargs)

    def check_exists(self, name):
        """Raise AttributeError unless ``name`` is a registered option."""
        if name not in self.values:
            raise AttributeError("Unrecognized config option: {}".format(name))

    def DEFINE_bool(self, name, default, *args, **kwargs):
        """Register a boolean option; extra arguments are kept as metadata."""
        self.add_option(name, default, bool, args, kwargs)

    def DEFINE_integer(self, name, default, *args, **kwargs):
        """Register an integer option; extra arguments are kept as metadata."""
        self.add_option(name, default, int, args, kwargs)

    def DEFINE_string(self, name, default, *args, **kwargs):
        """Register a string option; extra arguments are kept as metadata."""
        self.add_option(name, default, str, args, kwargs)

    def DEFINE_enum(self, name, default, *args, **kwargs):
        """Register an enum option; extra arguments are kept as metadata."""
        self.add_option(name, default, 'enum', args, kwargs)

    def config_with_absl(self):
        """Re-declare every registered option as an absl flag and switch to absl.

        After this call ``read`` and ``update`` go through ``absl.flags.FLAGS``.
        """
        # Run this before calling `app.run(main)` etc
        from absl import app, flags as absl_flags

        self.use_absl = True
        self.absl_flags = absl_flags
        absl_defs = {
            bool: absl_flags.DEFINE_bool,
            int: absl_flags.DEFINE_integer,
            str: absl_flags.DEFINE_string,
            'enum': absl_flags.DEFINE_enum,
        }
        for name, val in self.values.items():
            flag_type, meta_args, meta_kwargs = self.meta[name]
            absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
        # Registered with absl's call_after_init hook so the values are
        # pushed through update() once absl has initialised.
        app.call_after_init(lambda: self.complete_absl_config(absl_flags))

    def complete_absl_config(self, absl_flags):
        """Pass each registered option's absl flag value through ``update``."""
        for name in self.values:
            self.update(name, getattr(absl_flags.FLAGS, name))

    def parse_flags_with_absl(self):
        """Configure absl and parse ``sys.argv`` with it, at most once.

        The module-level ``already_configured_with_absl`` guard makes repeated
        calls no-ops. After parsing, omnistaging is enabled if the
        ``jax_omnistaging`` option is set.
        """
        global already_configured_with_absl
        if not already_configured_with_absl:
            import absl.flags

            self.config_with_absl()
            # NOTE(review): known_only=True presumably makes absl tolerate
            # arguments it does not recognise -- confirm against absl docs.
            absl.flags.FLAGS(sys.argv, known_only=True)
            self.complete_absl_config(absl.flags)
            already_configured_with_absl = True

            if FLAGS.jax_omnistaging:
                self.enable_omnistaging()

    # TODO(mattjj): remove this when omnistaging fully lands
    def enable_omnistaging(self):
        """Run every registered enabler once; later calls do nothing."""
        if not self.omnistaging_enabled:
            for enabler in self.omnistaging_enablers:
                enabler()
            self.omnistaging_enabled = True
class NameSpace:
    """Attribute-style facade over a lookup function.

    Any attribute not found through normal lookup is resolved by calling the
    getter supplied at construction time with the attribute's name.
    """

    def __init__(self, getter):
        # Stored as a regular instance attribute, so reading it is served by
        # normal lookup and does not go through __getattr__.
        self._getter = getter

    def __getattr__(self, name):
        """Delegate lookup of the missing attribute ``name`` to the getter."""
        resolve = self._getter
        return resolve(name)
# Guard read and set by Config.parse_flags_with_absl so that absl is only
# configured once.
already_configured_with_absl = False

# The module-wide configuration object; `flags` is a second name for the same
# object and `FLAGS` is its attribute-style read accessor.
config = Config()
flags = config
FLAGS = config.FLAGS

# Each option's default is taken from the correspondingly named environment
# variable, falling back to False when that variable is unset.
config.DEFINE_bool(
    'jax_enable_checks',
    bool_env('JAX_ENABLE_CHECKS', False),
    help='Turn on invariant checking (core.skip_checks = False)',
)

config.DEFINE_bool(
    'jax_omnistaging',
    bool_env('JAX_OMNISTAGING', False),
    help='Enable staging based on dynamic context rather than data dependence.',
)