# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO(phawkins): this file triggers a pytype bug.
# pytype: skip-file

import contextlib
import functools
import itertools
import os
import sys
import threading
from typing import Any, List, Callable, NamedTuple, Optional
import warnings

from jax._src import lib
from jax._src.lib import jax_jit


def bool_env(varname: str, default: bool) -> bool:
  """Read an environment variable and interpret it as a boolean.

  True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
  false values are 'n', 'no', 'f', 'false', 'off', and '0'.

  Args:
    varname: the name of the variable
    default: the default boolean value, used when the variable is unset.

  Raises:
    ValueError: if the environment variable is anything else.
  """
  val = os.getenv(varname, str(default))
  val = val.lower()
  if val in ('y', 'yes', 't', 'true', 'on', '1'):
    return True
  elif val in ('n', 'no', 'f', 'false', 'off', '0'):
    return False
  else:
    raise ValueError("invalid truth value %r for environment %r" %
                     (val, varname))


def int_env(varname: str, default: int) -> int:
  """Read an environment variable and interpret it as an integer.

  Args:
    varname: the name of the variable
    default: the default integer value, used when the variable is unset.

  Raises:
    ValueError: if the variable is set to something ``int()`` cannot parse.
  """
  return int(os.getenv(varname, str(default)))


class Config:
  """Registry of JAX configuration options.

  Option values are stored in ``self.values`` unless ``use_absl`` is set, in
  which case reads and writes are forwarded to absl's ``FLAGS`` object.
  """
  # Marker for pytype: option properties are attached dynamically through
  # setattr(Config, ...) by the define_*_state methods below.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self):
    self.values = {}
    self.meta = {}
    self.FLAGS = NameSpace(self.read, self.update)
    self.use_absl = False
    self._contextmanager_flags = set()
    self._update_hooks = {}

    self.omnistaging_enabled = True  # TODO(mattjj): remove this

  def update(self, name, val):
    """Set option ``name`` to ``val`` and run its global update hook, if any.

    Raises:
      AttributeError: (non-absl mode) if ``name`` is not a defined option.
    """
    if self.use_absl:
      setattr(self.absl_flags.FLAGS, name, val)
    else:
      # check_exists raises AttributeError for unknown names, so no second
      # membership test is needed before the assignment.
      self.check_exists(name)
      self.values[name] = val

    hook = self._update_hooks.get(name, None)
    if hook:
      hook(val)

  def read(self, name):
    """Return the global value of ``name`` (used by ``config.FLAGS.<name>``).

    Raises:
      AttributeError: if ``name`` has a context manager; such options must be
        read via ``config.<name>`` so thread-local overrides are honored.
    """
    if name in self._contextmanager_flags:
      raise AttributeError(
          "For flags with a corresponding contextmanager, read their value "
          f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
    return self._read(name)

  def _read(self, name):
    """Return the global value of ``name``, ignoring thread-local overrides."""
    if self.use_absl:
      return getattr(self.absl_flags.FLAGS, name)
    else:
      self.check_exists(name)
      return self.values[name]

  def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
                 update_hook=None):
    """Register a new option; ``update_hook`` is called once with ``default``.

    ``opt_type``, ``meta_args`` and ``meta_kwargs`` are kept in ``self.meta``
    so the option can later be re-declared as an absl flag.

    Raises:
      Exception: if an option called ``name`` already exists.
    """
    if name in self.values:
      raise Exception("Config option {} already defined".format(name))
    self.values[name] = default
    self.meta[name] = (opt_type, meta_args, meta_kwargs)
    if update_hook:
      self._update_hooks[name] = update_hook
      update_hook(default)

  def check_exists(self, name):
    """Raise AttributeError unless ``name`` is a defined option."""
    if name not in self.values:
      raise AttributeError("Unrecognized config option: {}".format(name))

  def DEFINE_bool(self, name, default, *args, **kwargs):
    update_hook = kwargs.pop("update_hook", None)
    self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)

  def DEFINE_integer(self, name, default, *args, **kwargs):
    update_hook = kwargs.pop("update_hook", None)
    self.add_option(name, default, int, args, kwargs, update_hook=update_hook)

  def DEFINE_string(self, name, default, *args, **kwargs):
    update_hook = kwargs.pop("update_hook", None)
    self.add_option(name, default, str, args, kwargs, update_hook=update_hook)

  def DEFINE_enum(self, name, default, *args, **kwargs):
    update_hook = kwargs.pop("update_hook", None)
    self.add_option(name, default, 'enum', args, kwargs,
                    update_hook=update_hook)

  def config_with_absl(self):
    """Re-declare every registered option as an absl flag.

    Run this before calling `app.run(main)` etc.
    """
    import absl.flags as absl_FLAGS  # noqa: F401
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags
    absl_defs = {bool: absl_flags.DEFINE_bool,
                 int: absl_flags.DEFINE_integer,
                 str: absl_flags.DEFINE_string,
                 'enum': absl_flags.DEFINE_enum}

    for name, val in self.values.items():
      flag_type, meta_args, meta_kwargs = self.meta[name]
      absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

    app.call_after_init(lambda: self.complete_absl_config(absl_flags))

  def complete_absl_config(self, absl_flags):
    """Push each parsed absl flag value through ``update`` to run its hook."""
    for name in self.values:
      self.update(name, getattr(absl_flags.FLAGS, name))

  def parse_flags_with_absl(self):
    """Parse only the ``--jax...`` flags from ``sys.argv`` via absl (once)."""
    global already_configured_with_absl
    if not already_configured_with_absl:
      # Extract just the --jax... flags (before the first --) from argv. In some
      # environments (e.g. ipython/colab) argv might be a mess of things
      # parseable by absl and other junk.
      jax_argv = itertools.takewhile(lambda a: a != '--', sys.argv)
      jax_argv = ['', *(a for a in jax_argv if a.startswith('--jax'))]

      import absl.flags
      self.config_with_absl()
      absl.flags.FLAGS(jax_argv, known_only=True)
      self.complete_absl_config(absl.flags)
      already_configured_with_absl = True

      if not FLAGS.jax_omnistaging:
        raise Exception(
          "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
          "see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.\n"
          "To remove this warning, unset the JAX_OMNISTAGING environment variable.")

  def enable_omnistaging(self):
    warnings.warn(
        "enable_omnistaging() is a no-op in JAX versions 0.2.12 and higher;\n"
        "see https://github.com/google/jax/blob/main/design_notes/omnistaging.md")

  def disable_omnistaging(self):
    raise Exception(
      "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
      "see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.")

  def define_bool_state(
      self, name: str, default: bool, help: str, *,
      update_global_hook: Optional[Callable[[bool], None]] = None,
      update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
      extra_description: str = ""):
    """Set up thread-local state and return a contextmanager for managing it.

    This function is a convenience wrapper. It defines a flag, environment
    variable, and corresponding thread-local state, which can be managed via the
    contextmanager it returns.

    The thread-local state value can be read via the ``config.<option_name>``
    attribute, where ``config`` is the singleton ``Config`` instance.

    Args:
      name: string, converted to lowercase to define the name of the config
        option (and absl flag). It is converted to uppercase to define the
        corresponding shell environment variable.
      default: boolean, a default value for the option. The environment
        variable, if set, takes precedence.
      help: string, used to populate the flag help information as well as the
        docstring of the returned context manager.
      update_global_hook: an optional callback that is called with the new
        global value whenever it is set, including once with the initial value
        when the option is defined.
      update_thread_local_hook: an optional callback that is called with the
        new thread-local value when the context manager is entered, and with
        the restored value (``None`` if there was no outer override) when it
        exits.
      extra_description: string, optional: extra information to add to the
        summary description.

    Returns:
      A contextmanager to control the thread-local state value.

    Example:

      enable_foo = config.define_bool_state(
          name='jax_enable_foo',
          default=False,
          help='Enable foo.')

      # Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
      # command-line flag can be used to control the process-level value of
      # the configuration option, in addition to using e.g.
      # ``config.update("jax_enable_foo", True)`` directly. We can also use a
      # context manager:

      with enable_foo(True):
        ...

    The value of the thread-local state or flag can be accessed via
    ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
    an error.
    """
    name = name.lower()
    self.DEFINE_bool(name, bool_env(name.upper(), default), help,
                     update_hook=update_global_hook)
    self._contextmanager_flags.add(name)
    _add_thread_local_property(name)

    return _StateContextManager(name, help, update_thread_local_hook,
                                extra_description=extra_description)

  def define_enum_state(
      self, name: str, enum_values: List[str], default: Optional[str],
      help: str, update_global_hook: Optional[Callable[[str], None]] = None,
      update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
      *, extra_description: str = ""):
    """Set up thread-local state and return a contextmanager for managing it.

    See docstring for ``define_bool_state``.

    Args:
      name: string, converted to lowercase to define the name of the config
        option (and absl flag). It is converted to uppercase to define the
        corresponding shell environment variable.
      enum_values: list of strings representing the possible values for the
        option.
      default: optional string, default value. The environment variable, if
        set, takes precedence and must also be one of ``enum_values``.
      help: string, used to populate the flag help information as well as the
        docstring of the returned context manager.
      update_global_hook: an optional callback that is called with the new
        global value whenever it is set, including once with the initial value.
      update_thread_local_hook: an optional callback that is called with the
        new thread-local value on entry to the context manager and with the
        restored value (``None`` if there was no outer override) on exit.
      extra_description: string, optional: extra information to add to the
        summary description.

    Returns:
      A contextmanager to control the thread-local state value.

    Raises:
      ValueError: if the (possibly environment-supplied) default is not None
        and not in ``enum_values``.
    """
    name = name.lower()
    default = os.getenv(name.upper(), default)
    if default is not None and default not in enum_values:
      raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
    self.DEFINE_enum(name, default,
                     enum_values=enum_values, help=help,
                     update_hook=update_global_hook)
    self._contextmanager_flags.add(name)
    _add_thread_local_property(name)

    def validate(new_val):
      if (new_val is not None and
          (type(new_val) is not str or new_val not in enum_values)):
        raise ValueError(f"new enum value must be None or in {enum_values}, "
                         f"got {new_val} of type {type(new_val)}.")

    return _StateContextManager(name, help, update_thread_local_hook, validate,
                                extra_description=extra_description)

  def define_string_state(
      self, name: str, default: Optional[str], help: str,
      update_global_hook: Optional[Callable[[str], None]] = None,
      update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
      *, extra_description: str = ""):
    """Set up thread-local state and return a contextmanager for managing it.

    See docstring for ``define_bool_state``.

    Args:
      name: string, converted to lowercase to define the name of the config
        option (and absl flag). It is converted to uppercase to define the
        corresponding shell environment variable.
      default: string, a default value for the option. The environment
        variable, if set, takes precedence.
      help: string, used to populate the flag help information as well as the
        docstring of the returned context manager.
      update_global_hook: an optional callback that is called with the new
        global value whenever it is set, including once with the initial value.
      update_thread_local_hook: an optional callback that is called with the
        new thread-local value on entry to the context manager and with the
        restored value (``None`` if there was no outer override) on exit.
      extra_description: string, optional: extra information to add to the
        summary description.

    Returns:
      A contextmanager to control the thread-local state value.
    """
    name = name.lower()
    default = os.getenv(name.upper(), default)
    self.DEFINE_string(name, default, help=help,
                       update_hook=update_global_hook)
    self._contextmanager_flags.add(name)
    _add_thread_local_property(name)

    def validate(new_val):
      if new_val is not None and not isinstance(new_val, str):
        raise ValueError(f"new string config value must be None or of type str,"
                         f" got {new_val} of type {type(new_val)}.")

    return _StateContextManager(name, help, update_thread_local_hook, validate,
                                extra_description=extra_description)

  def _trace_context(self):
    """Returns a tuple of configuration values that affect tracing.

    These values are included in the cache key for linear_util.cache.

    Values included in this set should also most likely be included in
    the C++ JIT state, which is handled separately."""
    return (self.x64_enabled, self.jax_numpy_rank_promotion,
            self.jax_default_matmul_precision)


class _StateContextManager:
  """Callable whose calls return a context manager overriding one option.

  The override is stored on ``_thread_local_state``, so it only affects the
  calling thread, and is undone when the ``with`` block exits.
  """

  def __init__(self, name, help, update_thread_local_hook,
               validate_new_val_hook: Optional[Callable[[Any], None]] = None,
               extra_description: str = ""):
    self._name = name
    self.__name__ = name[4:] if name.startswith('jax_') else name
    self.__doc__ = f"Context manager for `{name}` config option{extra_description}.\n\n{help}"
    self._update_thread_local_hook = update_thread_local_hook
    self._validate_new_val_hook = validate_new_val_hook

  @contextlib.contextmanager
  def __call__(self, new_val):
    # Validate before touching any state so a bad value leaves nothing behind.
    if self._validate_new_val_hook:
      self._validate_new_val_hook(new_val)
    prev_val = getattr(_thread_local_state, self._name, unset)
    setattr(_thread_local_state, self._name, new_val)
    if self._update_thread_local_hook:
      self._update_thread_local_hook(new_val)
    try:
      yield
    finally:
      if prev_val is unset:
        # No outer override: remove ours so reads fall back to the global value.
        delattr(_thread_local_state, self._name)
        if self._update_thread_local_hook:
          self._update_thread_local_hook(None)
      else:
        setattr(_thread_local_state, self._name, prev_val)
        if self._update_thread_local_hook:
          self._update_thread_local_hook(prev_val)

  def _add_hooks(self, update_global_hook, update_thread_local_hook):
    """Private method that adds hooks to an existing context-manager.

    Used to avoid cyclic import dependencies."""
    self._update_thread_local_hook = update_thread_local_hook
    config._update_hooks[self._name] = update_global_hook
    update_global_hook(config._read(self._name))


_thread_local_state = threading.local()

class _Unset: pass
unset = _Unset()


def _add_thread_local_property(name):
  """Expose option ``name`` as a read-only property on ``Config``.

  The property returns the current thread's override if one is active, and
  the global value (via ``Config._read``) otherwise.
  """
  def get_state(self):
    val = getattr(_thread_local_state, name, unset)
    return val if val is not unset else self._read(name)
  setattr(Config, name, property(get_state))


class NameSpace:
  """Attribute-style facade that forwards gets/sets to two callables."""

  def __init__(self, getter, setter):
    # must use super because we override this class's __setattr__, see
    # https://docs.python.org/3/reference/datamodel.html#object.__setattr__
    super().__setattr__('_getter', getter)
    super().__setattr__('_setter', setter)

  def __getattr__(self, name):
    return self._getter(name)

  def __setattr__(self, name, val):
    self._setter(name, val)


config = Config()
flags = config
FLAGS = flags.FLAGS

already_configured_with_absl = False


# The C++ JIT maintains its own copy of several configuration items as
# a global/thread-local state. These methods allow updates to part of the
# state when a configuration value changes.

class GlobalJitState(NamedTuple):
  numpy_rank_promotion: Optional[str] = None
  default_matmul_precision: Optional[Any] = None


def update_global_jit_state(**kw):
  """Replace the given fields of the C++ JIT's global ``extra_jit_context``."""
  gs = jax_jit.global_state()
  context = gs.extra_jit_context or GlobalJitState()
  gs.extra_jit_context = context._replace(**kw)


class ThreadLocalJitState(NamedTuple):
  dynamic_trace_state: Optional[Any] = None
  numpy_rank_promotion: Optional[str] = None
  default_matmul_precision: Optional[Any] = None


def update_thread_local_jit_state(**kw):
  """Replace the given fields of the C++ JIT's thread-local context."""
  tls = jax_jit.thread_local_state()
  context = tls.extra_jit_context or ThreadLocalJitState()
  tls.extra_jit_context = context._replace(**kw)


# TODO(mattjj): remove all uses of this flag
flags.DEFINE_bool(
    'jax_omnistaging',
    bool_env('JAX_OMNISTAGING', True),
    help=('Deprecated. Setting this flag to False raises an error. Setting it '
          'to True has no effect.'),
)

flags.DEFINE_integer(
    'jax_tracer_error_num_traceback_frames',
    int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
    help='Set the number of stack frames in JAX tracer error messages.'
)

flags.DEFINE_bool(
    'jax_host_callback_inline',
    bool_env('JAX_HOST_CALLBACK_INLINE', False),
    help='Inline the host_callback, if not in a staged context.'
)
flags.DEFINE_integer(
    'jax_host_callback_max_queue_byte_size',
    int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
    help=('The size in bytes of the buffer used to hold outfeeds from each '
          'device. When this capacity is reached consuming outfeeds from the '
          'device is paused, thus potentially pausing the device computation, '
          'until the Python callback consume more outfeeds.'),
    lower_bound=int(16 * 1e6)
)
flags.DEFINE_bool(
    'jax_host_callback_outfeed',
    bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
    help=(
        'Use outfeed implementation for host_callback, even on CPU and GPU. '
        'If false, use the CustomCall implementation. '
        'Has no effect on TPU, since only the outfeed mechanism is implemented.'
    )
)
flags.DEFINE_bool(
    'jax_host_callback_ad_transforms',
    bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
    help=(
        'Enable support for jvp/vjp for the host_callback primitives. Default is '
        'False, which means that host_callback operates only on primals. '
        'The flag exists only temporarily, for backward compatibility.'
    )
)

enable_checks = config.define_bool_state(
    name='jax_enable_checks',
    default=False,
    help='Turn on invariant checking for JAX internals. Makes things slower.')

check_tracer_leaks = config.define_bool_state(
    name='jax_check_tracer_leaks',
    default=False,
    help=('Turn on checking for leaked tracers as soon as a trace completes. '
          'Enabling leak checking may have performance impacts: some caching '
          'is disabled, and other overheads may be added. Additionally, be aware '
          'that some Python debuggers can cause false positives, so it is recommended '
          'to disable any debuggers while leak checking is enabled.'))
checking_leaks = functools.partial(check_tracer_leaks, True)

debug_nans = config.define_bool_state(
    name='jax_debug_nans',
    default=False,
    help=('Add nan checks to every operation. When a nan is detected on the '
          'output of a jit-compiled computation, call into the un-compiled '
          'version in an attempt to more precisely identify the operation '
          'which produced the nan.'))

debug_infs = config.define_bool_state(
    name='jax_debug_infs',
    default=False,
    help=('Add inf checks to every operation. When an inf is detected on the '
          'output of a jit-compiled computation, call into the un-compiled '
          'version in an attempt to more precisely identify the operation '
          'which produced the inf.'))

log_compiles = config.define_bool_state(
    name='jax_log_compiles',
    default=False,
    help=('Log a message each time `jit` or `pmap` compiles an XLA '
          'computation. Logging is performed with `absl.logging`. When this '
          'option is set, the log level is WARNING; otherwise the level is '
          'DEBUG.'))

parallel_functions_output_gda = config.define_bool_state(
    name='jax_parallel_functions_output_gda',
    default=False,
    help='If True, pjit will output GSDAs.')

distributed_debug = config.define_bool_state(
    name='jax_distributed_debug',
    default=False,
    help=('Enable logging useful for debugging multi-process distributed '
          'computations. Logging is performed with `absl.logging` at WARNING '
          'level.'))

enable_custom_prng = config.define_bool_state(
    name='jax_enable_custom_prng',
    default=False,
    help=('Enables an internal upgrade that allows one to define custom '
          'pseudo-random number generator implementations. This will '
          'be enabled by default in future versions of JAX, at which point '
          'disabling it will be considered deprecated. In a version '
          'after that the flag will be removed altogether.'),
    extra_description=" (transient)")

default_prng_impl = config.define_enum_state(
    name='jax_default_prng_impl',
    enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
    default='threefry2x32',
    help=('Select the default PRNG implementation, used when one is not '
          'explicitly provided at seeding time.'))

hlo_source_file_canonicalization_regex = config.define_string_state(
    name='jax_hlo_source_file_canonicalization_regex',
    default=None,
    help=('Used to canonicalize the source_path metadata of HLO instructions '
          'by removing the given regex. If set, re.sub() is called on each '
          'source_file with the given regex, and all matches are removed. '
          'This can be used to avoid spurious cache misses when using the '
          'persistent compilation cache, which includes HLO metadata in the '
          'cache key.'))

config.define_enum_state(
    name='jax_default_dtype_bits',
    enum_values=['32', '64'],
    default='64',
    help=('Specify bit width of default dtypes, either 32-bit or 64-bit. '
          'This is a temporary flag that will be used during the process '
          'of deprecating the ``jax_enable_x64`` flag.'))


def _update_x64_global(val):
  lib.jax_jit.global_state().enable_x64 = val


def _update_x64_thread_local(val):
  lib.jax_jit.thread_local_state().enable_x64 = val


enable_x64 = config.define_bool_state(
    name='jax_enable_x64',
    default=False,
    help='Enable 64-bit types to be used',
    update_global_hook=_update_x64_global,
    update_thread_local_hook=_update_x64_thread_local)

# TODO(phawkins): remove after fixing users of FLAGS.x64_enabled.
# Let ``config.FLAGS.jax_enable_x64`` keep working: Config.read refuses to
# serve names listed in _contextmanager_flags.
config._contextmanager_flags.remove("jax_enable_x64")

# Legacy alias so ``config.x64_enabled`` reads the same property.
setattr(Config, 'x64_enabled', Config.jax_enable_x64)


def _update_disable_jit_global(val):
  """Mirror the global ``jax_disable_jit`` value into the C++ JIT state."""
  lib.jax_jit.global_state().disable_jit = val


def _update_disable_jit_thread_local(val):
  """Mirror a thread-local ``jax_disable_jit`` override into the C++ JIT."""
  lib.jax_jit.thread_local_state().disable_jit = val


disable_jit = config.define_bool_state(
    name='jax_disable_jit',
    default=False,
    help='Disable JIT compilation and just call original Python.',
    update_global_hook=_update_disable_jit_global,
    update_thread_local_hook=_update_disable_jit_thread_local)


def _update_numpy_rank_promotion_global(val):
  """Forward the global rank-promotion mode to the C++ JIT context."""
  update_global_jit_state(numpy_rank_promotion=val)


def _update_numpy_rank_promotion_thread_local(val):
  """Forward a thread-local rank-promotion override to the C++ JIT context."""
  update_thread_local_jit_state(numpy_rank_promotion=val)


numpy_rank_promotion = config.define_enum_state(
    name='jax_numpy_rank_promotion',
    enum_values=['allow', 'warn', 'raise'],
    default='allow',
    help=('Control NumPy-style automatic rank promotion broadcasting '
          '("allow", "warn", or "raise").'),
    update_global_hook=_update_numpy_rank_promotion_global,
    update_thread_local_hook=_update_numpy_rank_promotion_thread_local)


_DEFAULT_MATMUL_PRECISION_HELP = (
    'Control the default matmul and conv precision for 32bit inputs.\n\n'

    'Some platforms, like TPU, offer configurable precision levels for '
    'matrix multiplication and convolution computations, trading off '
    'accuracy for speed. The precision can be controlled for each '
    'operation; for example, see the :func:`jax.lax.conv_general_dilated` '
    'and :func:`jax.lax.dot` docstrings. But it can be useful to control '
    'the default behavior obtained when an operation is not given a '
    'specific precision.\n\n'

    'This option can be used to control the default precision '
    'level for computations involved in matrix multiplication and '
    'convolution on 32bit inputs. The levels roughly describe the '
    "precision at which scalar products are computed. The 'bfloat16' "
    "option is the fastest and least precise; 'float32' is similar to "
    "full float32 precision; 'tensorfloat32' is intermediate.\n\n"
)


def _update_default_matmul_precision_global(val):
  """Forward the global default matmul precision to the C++ JIT context."""
  update_global_jit_state(default_matmul_precision=val)


def _update_default_matmul_precision_thread_local(val):
  """Forward a thread-local matmul precision override to the C++ JIT."""
  update_thread_local_jit_state(default_matmul_precision=val)


default_matmul_precision = config.define_enum_state(
    name='jax_default_matmul_precision',
    enum_values=['bfloat16', 'tensorfloat32', 'float32'],
    default=None,
    help=_DEFAULT_MATMUL_PRECISION_HELP,
    update_global_hook=_update_default_matmul_precision_global,
    update_thread_local_hook=_update_default_matmul_precision_thread_local)


_TRACEBACK_FILTERING_HELP = (
    'Controls how JAX filters internal frames out of tracebacks.\n\n'
    'Valid values are:\n'
    ' * "off": disables traceback filtering.\n'
    ' * "auto": use "tracebackhide" if running under a sufficiently '
    'new IPython, or "remove_frames" otherwise.\n'
    ' * "tracebackhide": adds "__tracebackhide__" annotations to '
    ' hidden stack frames, which some traceback printers support.\n'
    ' * "remove_frames": removes hidden frames from tracebacks, and adds '
    ' the unfiltered traceback as a __cause__ of the exception.\n'
)

traceback_filtering = config.define_enum_state(
    name='jax_traceback_filtering',
    enum_values=['off', 'tracebackhide', 'remove_frames', 'auto'],
    default='auto',
    help=_TRACEBACK_FILTERING_HELP)

enable_mlir = config.define_bool_state(
    name='jax_enable_mlir',
    default=False,
    help=('Enables an experimental code path that compiles JAX programs via '
          'emitting the MLIR MHLO dialect.'))