mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00

This change, when enabled, stages out all primitive calls in the dynamic scope of a jitted, pmapped, or control flow function, rather than only staging out based on data dependence. One improvement is that jitted functions can consume less memory, by avoiding instantiating large constants at trace time, and cause less memory fragmentation as well. It also simplifies several internals. See https://github.com/google/jax/pull/3370 for more information.
155 lines
4.7 KiB
Python
155 lines
4.7 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import sys
|
|
|
|
def bool_env(varname: str, default: bool) -> bool:
  """Interpret the environment variable ``varname`` as a boolean.

  Matching is case-insensitive. Accepted spellings are:
    * true:  'y', 'yes', 't', 'true', 'on', '1'
    * false: 'n', 'no', 'f', 'false', 'off', '0'

  Args:
    varname: name of the environment variable to read.
    default: value used when the variable is unset; it is converted with
      ``str`` and then parsed exactly like a value read from the environment.

  Returns:
    The boolean spelled by the variable (or by ``default``).

  Raises:
    ValueError: if the (lower-cased) value matches none of the spellings.
  """
  spellings = dict.fromkeys(('y', 'yes', 't', 'true', 'on', '1'), True)
  spellings.update(dict.fromkeys(('n', 'no', 'f', 'false', 'off', '0'), False))

  text = os.getenv(varname, str(default)).lower()
  parsed = spellings.get(text)
  if parsed is None:
    raise ValueError("invalid truth value %r for environment %r" % (text, varname))
  return parsed
|
|
|
|
|
|
class Config:
  """Registry of JAX configuration options.

  Options are registered with the ``DEFINE_*`` methods and accessed with
  ``read``/``update`` or attribute-style through ``self.FLAGS``. Until
  ``config_with_absl`` runs, values are stored in ``self.values``; after it
  sets ``use_absl``, reads and writes are forwarded to ``absl.flags.FLAGS``.
  """

  def __init__(self):
    # Option name -> current value (authoritative while use_absl is False).
    self.values = {}
    # Option name -> (type tag, positional args, keyword args), recorded at
    # definition time so config_with_absl can re-declare it as an absl flag.
    self.meta = {}
    # Attribute-style access: ``self.FLAGS.foo`` calls ``self.read('foo')``.
    self.FLAGS = NameSpace(self.read)
    self.use_absl = False
    # Omnistaging transition state (see enable_omnistaging).
    self.omnistaging_enabled = False
    # Zero-argument callables invoked by enable_omnistaging.
    self.omnistaging_enablers = []

  def update(self, name, val):
    """Set option ``name`` to ``val``.

    When absl is in use the value is written to ``absl_flags.FLAGS``;
    otherwise it is stored in ``self.values``.

    Raises:
      AttributeError: if absl is not in use and ``name`` was never defined.
    """
    if self.use_absl:
      setattr(self.absl_flags.FLAGS, name, val)
    else:
      # check_exists raises for unknown names, so no second check is needed.
      self.check_exists(name)
      self.values[name] = val

  def read(self, name):
    """Return the current value of option ``name``.

    Raises:
      AttributeError: if absl is not in use and ``name`` was never defined.
    """
    if self.use_absl:
      return getattr(self.absl_flags.FLAGS, name)
    else:
      self.check_exists(name)
      return self.values[name]

  def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
    """Register option ``name`` with its default value and definition metadata.

    Raises:
      Exception: if an option called ``name`` is already defined.
    """
    if name in self.values:
      raise Exception("Config option {} already defined".format(name))
    self.values[name] = default
    self.meta[name] = (opt_type, meta_args, meta_kwargs)

  def check_exists(self, name):
    """Raise AttributeError unless ``name`` is a defined option."""
    if name not in self.values:
      raise AttributeError("Unrecognized config option: {}".format(name))

  def DEFINE_bool(self, name, default, *args, **kwargs):
    """Define a boolean option; extra args are kept for the absl flag."""
    self.add_option(name, default, bool, args, kwargs)

  def DEFINE_integer(self, name, default, *args, **kwargs):
    """Define an integer option; extra args are kept for the absl flag."""
    self.add_option(name, default, int, args, kwargs)

  def DEFINE_string(self, name, default, *args, **kwargs):
    """Define a string option; extra args are kept for the absl flag."""
    self.add_option(name, default, str, args, kwargs)

  def DEFINE_enum(self, name, default, *args, **kwargs):
    """Define an enum option; extra args are kept for the absl flag."""
    self.add_option(name, default, 'enum', args, kwargs)

  def config_with_absl(self):
    """Re-declare every option as an absl flag and route access through absl.

    Run this before calling ``app.run(main)`` etc.
    """
    import absl.flags as absl_FLAGS  # noqa: F401
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags

    # Maps the type tag stored in self.meta to the absl definer.
    absl_defs = {bool: absl_flags.DEFINE_bool,
                 int: absl_flags.DEFINE_integer,
                 str: absl_flags.DEFINE_string,
                 'enum': absl_flags.DEFINE_enum}

    for name, val in self.values.items():
      flag_type, meta_args, meta_kwargs = self.meta[name]
      absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

    app.call_after_init(lambda: self.complete_absl_config(absl_flags))

  def complete_absl_config(self, absl_flags):
    """Pass each option's value from ``absl_flags.FLAGS`` through ``update``."""
    for name in self.values:
      self.update(name, getattr(absl_flags.FLAGS, name))

  def parse_flags_with_absl(self):
    """Configure absl and parse known flags from ``sys.argv`` (first call only).

    Afterwards, enables omnistaging if the ``jax_omnistaging`` option is set.
    """
    global already_configured_with_absl
    if not already_configured_with_absl:
      import absl.flags
      self.config_with_absl()
      absl.flags.FLAGS(sys.argv, known_only=True)
      self.complete_absl_config(absl.flags)
      already_configured_with_absl = True

    # Read through this instance, not the module-level FLAGS alias, so the
    # method behaves correctly on any Config object.
    if self.FLAGS.jax_omnistaging:
      self.enable_omnistaging()

  # TODO(mattjj): remove this when omnistaging fully lands
  def enable_omnistaging(self):
    """Run every registered omnistaging enabler; no-op once this has completed."""
    if not self.omnistaging_enabled:
      for enabler in self.omnistaging_enablers:
        enabler()
      self.omnistaging_enabled = True
|
|
|
|
|
|
class NameSpace:
  """Attribute-style facade over a getter: ``ns.foo`` returns ``getter('foo')``."""

  def __init__(self, getter):
    # Callable taking an attribute name and returning its value.
    self._getter = getter

  def __getattr__(self, name):
    # __getattr__ only runs for names not found by normal lookup. If _getter
    # itself is missing (e.g. an instance created via __new__ by copy/pickle
    # before its state is restored), looking it up would re-enter this method
    # forever; raise AttributeError instead so callers like hasattr() work.
    if name == '_getter':
      raise AttributeError(name)
    return self._getter(name)
|
|
|
|
|
|
# Process-wide configuration object. `flags` and `FLAGS` are aliases that
# support absl-style call sites (`flags.DEFINE_*`, `FLAGS.<option>`).
config = Config()
flags = config
FLAGS = flags.FLAGS

# Flipped to True by Config.parse_flags_with_absl after its first run.
already_configured_with_absl = False

# Option defaults come from the matching JAX_* environment variable via
# bool_env, falling back to False when the variable is unset.
flags.DEFINE_bool(
    name='jax_enable_checks',
    default=bool_env('JAX_ENABLE_CHECKS', False),
    help='Turn on invariant checking (core.skip_checks = False)',
)

flags.DEFINE_bool(
    name='jax_omnistaging',
    default=bool_env('JAX_OMNISTAGING', False),
    help='Enable staging based on dynamic context rather than data dependence.',
)
|