Migrate a subset of internal modules to use state objects

The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are friendlier to type checkers and IDEs.
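
To make the migration concrete, here is a minimal sketch of the access pattern before and after (the names are taken from holders touched below; this is illustrative only, not the full `jax._src.config` API):

    # Before: dynamic, stringly-typed attribute lookup on the config singleton.
    #   from jax._src.config import config
    #   if config.jax_enable_checks: ...

    # After: a module-level, statically-typed state object, read via .value.
    from jax._src import config

    if config.enable_checks.value:
        ...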

PiperOrigin-RevId: 571932143
Sergei Lebedev authored 2023-10-09 07:28:18 -07:00, committed by jax authors
parent 14414363d0
commit 65d3058944
27 changed files with 232 additions and 212 deletions

View File

@@ -23,6 +23,7 @@ import numpy as np
from jax._src import ad_util
from jax._src import api
+from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
@@ -31,7 +32,6 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
-from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -529,7 +529,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
-logger.log(logging.WARNING if config.jax_log_checkpoint_residuals
+logger.log(logging.WARNING if config.log_checkpoint_residuals.value
else logging.DEBUG,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
@@ -659,7 +659,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
assert not jaxpr.constvars
if differentiated and prevent_cse:
-if config.jax_remat_opt_barrier:
+if config.remat_opt_barrier.value:
translation_rule = _remat_translation_using_opt_barrier
elif is_gpu_platform:
translation_rule = _remat_translation_using_while

View File

@@ -27,13 +27,13 @@ from jax._src import abstract_arrays
from jax._src import api
from jax._src import api_util
from jax._src import basearray
+from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
-from jax._src.config import config
from jax._src.lib import xla_client as xc
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@@ -172,7 +172,7 @@ class ArrayImpl(basearray.Array):
# input buffers are already arranged properly. This usually happens when
# Array's are created as output of a JAX transformation
# (like pjit, xmap, etc).
-if not _skip_checks or config.jax_enable_checks:
+if not _skip_checks or config.enable_checks.value:
self._check_and_rearrange()
def _check_and_rearrange(self):

View File

@@ -20,7 +20,7 @@ import os
import struct
import sys
-from jax._src.config import config
+from jax._src import config
from jax._src.lib import version as jaxlib_version
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_client
@@ -167,7 +167,7 @@ def _canonicalize_ir(m_original: ir.Module) -> bytes:
def _hash_computation(hash_obj, module):
-if config.jax_compilation_cache_include_metadata_in_key:
+if config.compilation_cache_include_metadata_in_key.value:
canonical_ir = _serialize_ir(module)
else:
canonical_ir = _canonicalize_ir(module)

View File

@@ -27,6 +27,7 @@ from jax import lax
from jax._src import api
from jax._src import linear_util as lu
+from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
@@ -37,7 +38,6 @@ from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
-from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -511,7 +511,7 @@ def check_lowering_rule(ctx, *args, err_tree, debug):
if debug:
# NOOP (check will only trigger when discharged)
return []
-if not config.jax_experimental_unsafe_xla_runtime_errors:
+if not config.xla_runtime_errors.value:
raise functionalization_error
out_op, _, _ = mlir.emit_python_callback(

View File

@@ -30,32 +30,31 @@ import numpy as np
from jax._src import lib
from jax._src import compilation_cache
-from jax._src import config as jax_config
+from jax._src import config as config
from jax._src import monitoring
from jax._src import path
from jax._src import profiler
from jax._src import traceback_util
-from jax._src.config import config
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
-_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
+_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool(
'jax_disable_most_optimizations',
-jax_config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
+config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
-_DUMP_IR_TO = jax_config.DEFINE_string(
+_DUMP_IR_TO = config.DEFINE_string(
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
help="Path to which the IR that is emitted by JAX as input to the "
"compiler should be dumped as text files. Optional. If omitted, JAX "
"will not dump IR.")
-_COMPILER_DETAILED_LOGGING_MIN_OPS = jax_config.DEFINE_integer(
+_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer(
"jax_compiler_detailed_logging_min_ops",
-jax_config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
+config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
help=(
'How big should a module be in MLIR operations before JAX enables '
'detailed compiler logging? The intent of this flag is to suppress '
@@ -192,7 +191,7 @@ def get_compile_options(
# If the function returns 0, set -1; this is an error.
# -1 indicates that no attempt should be made to retrieve the latest profile
# later on.
-jax_xla_profile_version = config.jax_xla_profile_version
+jax_xla_profile_version = config.jax_xla_profile_version.value
if jax_xla_profile_version > 0:
compile_options.profile_version = jax_xla_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
@@ -306,7 +305,7 @@ def compile_or_get_cached(
try:
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend,
-jax_config.config.jax_use_original_compilation_cache_key_generation,
+config.config.jax_use_original_compilation_cache_key_generation,
)
except xc._xla.XlaRuntimeError as ex:
logger.error("compile_or_get_cached: unable to generate cache key, "
@@ -325,7 +324,7 @@ def compile_or_get_cached(
# TODO(b/293308239) Instrument metrics for new cache savings and cache hit
# rate after it is enabled.
-if jax_config.config.jax_use_original_compilation_cache_key_generation:
+if config.config.jax_use_original_compilation_cache_key_generation:
# TODO(b/293308239) Remove metrics for the original cache after the new
# compilation cache key implementation is fully rolled out.
monitoring.record_event('/jax/compilation_cache/cache_hits_original')
@@ -358,7 +357,7 @@ def _cache_read(
return compilation_cache.get_executable_and_time(
cache_key, compile_options, backend)
except Exception as ex:
-if config.jax_raise_persistent_cache_errors:
+if config.raise_persistent_cache_errors.value:
raise
warnings.warn(
f"Error reading persistent compilation cache entry for "
@@ -380,7 +379,7 @@ def _cache_write(cache_key: str,
"callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
return
-min_compile_time = config.jax_persistent_cache_min_compile_time_secs
+min_compile_time = config.persistent_cache_min_compile_time_secs.value
if min_compile_time:
if compile_time_secs < min_compile_time:
logger.debug(
@@ -399,7 +398,7 @@ def _cache_write(cache_key: str,
compilation_cache.put_executable_and_time(
cache_key, module_name, executable, backend, int(compile_time_secs))
except Exception as ex:
-if config.jax_raise_persistent_cache_errors:
+if config.raise_persistent_cache_errors.value:
raise
warnings.warn(
f"Error writing persistent compilation cache entry for "

View File

@@ -22,7 +22,7 @@ import logging
import os
import sys
import threading
-from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar
+from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar
from jax._src import lib
from jax._src.lib import jax_jit
@@ -75,6 +75,12 @@ class FlagHolder(Generic[_T]):
self._flags = flags
self._name = name
+def __bool__(self) -> NoReturn:
+raise TypeError(
+"bool() not supported for instances of type '{0}' "
+"(did you mean to use '{0}.value' instead?)".format(
+type(self).__name__))
@property
def value(self) -> _T:
return getattr(self._flags, self._name)
@@ -523,6 +529,12 @@ class _StateContextManager(Generic[_T]):
self._validate_new_val_hook = validate_new_val_hook
self._default_value = default_value
+def __bool__(self) -> NoReturn:
+raise TypeError(
+"bool() not supported for instances of type '{0}' "
+"(did you mean to use '{0}.value' instead?)".format(
+type(self).__name__))
@property
def value(self) -> _T:
val = _thread_local_state.__dict__.get(self._name, unset)
@@ -935,7 +947,7 @@ use_original_compilation_cache_key_generation = config.define_bool_state(
"deployed, this flag and the original cache-key generation algorithm "
"will be removed.")
-config.define_enum_state(
+default_dtype_bits = config.define_enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
default='64',
@@ -1090,7 +1102,7 @@ bcoo_cusparse_lowering = config.define_bool_state(
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
-config.define_bool_state(
+dynamic_shapes = config.define_bool_state(
name='jax_dynamic_shapes',
default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
help=('Enables experimental features for staging out computations with '
@@ -1102,19 +1114,19 @@ config.define_bool_state(
# This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints.
-config.define_bool_state(
+remat_opt_barrier = config.define_bool_state(
name='jax_remat_opt_barrier',
default=(lib.version >= (0, 3, 6)),
help=('Enables using optimization-barrier op for lowering remat.'))
# TODO(sharadmv,mattjj): set default to True, then remove
-config.define_bool_state(
+eager_pmap = config.define_bool_state(
name='jax_eager_pmap',
default=True,
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
-config.define_bool_state(
+xla_runtime_errors = config.define_bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
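
Because the holders above now define `__bool__` to raise, accidentally testing the holder itself (the old `config.jax_*` habit) fails loudly instead of being silently truthy. A minimal sketch of the intended usage, assuming `enable_checks` is one of the state objects this commit reads via `.value`:

    from jax._src import config

    # Reads go through the .value property.
    if config.enable_checks.value:
        print("extra runtime checks are enabled")

    # Truthiness of the holder itself is a TypeError that points at .value.
    try:
        bool(config.enable_checks)
    except TypeError as err:
        print(err)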

View File

@@ -36,9 +36,8 @@ from weakref import ref
import numpy as np
from jax._src import dtypes
-from jax._src import config as jax_config
+from jax._src import config
from jax._src import effects
-from jax._src.config import config
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@@ -60,9 +59,9 @@ zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
-_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer(
+_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.DEFINE_integer(
'jax_tracer_error_num_traceback_frames',
-jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
+config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
help='Set the number of stack frames in JAX tracer error messages.'
)
@@ -269,7 +268,7 @@ class JaxprEqn(NamedTuple):
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
source_info = source_info or source_info_util.new_source_info()
-if config.jax_enable_checks:
+if config.enable_checks.value:
assert all(isinstance(x, (Var, Literal)) for x in invars)
assert all(isinstance(v, Var) for v in outvars)
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
@@ -381,7 +380,7 @@ class Primitive:
return f'{self.name}'
def bind(self, *args, **params):
-assert (not config.jax_enable_checks or
+assert (not config.enable_checks.value or
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
return self.bind_with_trace(find_top_trace(args), args, params)
@@ -438,7 +437,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
return v.val if isinstance(v, Literal) else env[v]
def write(v: Var, val: Any) -> None:
-if config.jax_enable_checks and not config.jax_dynamic_shapes:
+if config.enable_checks.value and not config.dynamic_shapes.value:
assert typecheck(v.aval, val), (v.aval, val)
env[v] = val
@@ -739,7 +738,7 @@ class Tracer(typing.Array):
def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
-assert not config.jax_enable_checks or name != "aval"
+assert not config.enable_checks.value or name != "aval"
try:
attr = getattr(self.aval, name)
@@ -989,7 +988,7 @@ def _update_thread_local_jit_state(dynamic):
# TODO(mattjj): add a test that verifies that JIT-ted functions are not kept
# alive by the JIT cache, particularly for nested JIT-ted functions.
copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
-jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)
+config.update_thread_local_jit_state(dynamic_trace_state=copy)
# The global state of the tracer is accessed by a thread-local object.
@@ -1015,7 +1014,7 @@ def _initialize_jax_jit_thread_local_state():
if tls.extra_jit_context is None:
dynamic = thread_local_state.trace_state.trace_stack.dynamic
copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
-jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)
+config.update_thread_local_jit_state(dynamic_trace_state=copy)
jax_jit.set_thread_local_state_initialization_callback(
@@ -1151,7 +1150,7 @@ def new_main(trace_type: type[Trace], dynamic: bool = False,
stack.dynamic = prev_dynamic
_update_thread_local_jit_state(stack.dynamic)
-if config.jax_check_tracer_leaks:
+if config.check_tracer_leaks.value:
t = ref(main)
del main
if t() is not None:
@@ -1188,7 +1187,7 @@ def new_base_main(trace_type: type[Trace],
stack.stack[0] = prev_base
_update_thread_local_jit_state(stack.dynamic)
-if config.jax_check_tracer_leaks:
+if config.check_tracer_leaks.value:
t = ref(main)
del main
if t() is not None:
@@ -1268,7 +1267,7 @@ def new_sublevel() -> Generator[None, None, None]:
finally:
thread_local_state.trace_state.substack.pop()
-if config.jax_check_tracer_leaks:
+if config.check_tracer_leaks.value:
t = ref(sublevel)
del sublevel
if t() is not None:
@@ -2026,9 +2025,9 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
return operator.index(dim)
except TypeError as e:
type_error = e
-if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
+if isinstance(dim, Tracer) and config.dynamic_shapes.value:
return dim
-elif (config.jax_dynamic_shapes and isinstance(dim, DArray) and
+elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
type(dim._aval.dtype) is bint and not dim._aval.shape):
return dim
elif is_dim(dim):
@@ -2230,7 +2229,7 @@ class CallPrimitive(Primitive):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
-if config.jax_dynamic_shapes:
+if config.dynamic_shapes.value:
subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
@@ -2458,14 +2457,14 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
frame = AxisEnvFrame(axis_name, size, tag)
ts = thread_local_state.trace_state
ts.axis_env.append(frame)
-jax_config.update_thread_local_jit_state(
+config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
try:
yield
finally:
ts.axis_env.pop()
-jax_config.update_thread_local_jit_state(
+config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
@@ -2474,14 +2473,14 @@ def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None):
frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
ts = thread_local_state.trace_state
ts.axis_env.extend(frames)
-jax_config.update_thread_local_jit_state(
+config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
try:
yield
finally:
for _ in frames: ts.axis_env.pop()
-jax_config.update_thread_local_jit_state(
+config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
@@ -2493,12 +2492,12 @@ def stash_axis_env():
# be raised.
ts = thread_local_state.trace_state
prev_axis_env, ts.axis_env = ts.axis_env, []
-jax_config.update_thread_local_jit_state(axis_env_state=())
+config.update_thread_local_jit_state(axis_env_state=())
try:
yield
finally:
ts.axis_env = prev_axis_env
-jax_config.update_thread_local_jit_state(
+config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))

View File

@@ -18,6 +18,7 @@ from functools import update_wrapper, reduce, partial
import inspect
from typing import Any, Callable, Generic, Optional, TypeVar
+from jax._src import config
from jax._src import core
from jax._src import custom_api_util
from jax._src.custom_transpose import custom_transpose
@@ -28,7 +29,6 @@ from jax._src import traceback_util
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
-from jax._src.config import config
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
@@ -590,7 +590,7 @@ class custom_vjp(Generic[ReturnValue]):
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = _resolve_kwargs(self.fun, args, kwargs)
-if config.jax_enable_custom_vjp_by_custom_transpose:
+if config.enable_custom_vjp_by_custom_transpose.value:
if self.nondiff_argnums:
raise NotImplementedError(
'nondiff_argnums not implemented for new custom_vjp')
@@ -1072,7 +1072,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
"""
flat_args, in_tree = tree_flatten(example_args)
in_avals = tuple(map(abstractify, flat_args))
-if config.jax_check_tracer_leaks:
+if config.check_tracer_leaks.value:
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)

View File

@@ -30,6 +30,7 @@ import numpy as np
import jax
from jax._src import basearray
+from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
@@ -39,7 +40,6 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src import xla_bridge as xb
-from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -160,7 +160,7 @@ def xla_primitive_callable(
lowering_parameters=mlir.LoweringParameters())
compiled = computation.compile()
if xla_extension_version >= 192:
-if config.jax_disable_jit:
+if config.disable_jit.value:
call = compiled.unsafe_call
else:
call = compiled.create_cpp_call_for_apply_primitive(out_tree())
@@ -262,7 +262,7 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
if _on_exit:
yield
else:
-log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
+log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
start_time = time.time()
yield
elapsed_time = time.time() - start_time
@@ -395,7 +395,7 @@ def _initial_style_primitive_replicas(params: dict[str, Any]) -> int:
default=1)
def needs_check_special() -> bool:
-return config.jax_debug_infs or config.jax_debug_nans
+return config.debug_infs.value or config.debug_nans.value
def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
if needs_check_special():
@@ -404,9 +404,9 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
if dtypes.issubdtype(dtype, np.inexact):
-if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
+if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
-if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
+if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")

View File

@@ -19,7 +19,7 @@ import os
from typing import Any, Optional, Union
from jax._src import clusters
-from jax._src.config import config
+from jax._src import config
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
@@ -63,8 +63,8 @@ class State:
if local_device_ids:
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
-config.update("jax_cuda_visible_devices", visible_devices)
+config.config.update("jax_cuda_visible_devices", visible_devices)
-config.update("jax_rocm_visible_devices", visible_devices)
+config.config.update("jax_rocm_visible_devices", visible_devices)
self.process_id = process_id

View File

@@ -30,7 +30,7 @@ import warnings
import ml_dtypes
import numpy as np
-from jax._src.config import config
+from jax._src import config
from jax._src.typing import DType, DTypeLike
from jax._src import traceback_util
@@ -59,7 +59,6 @@ class extended(np.generic):
>>> jnp.issubdtype(key.dtype, dtypes.extended)
True
"""
-pass
class prng_key(extended):
@@ -75,7 +74,6 @@ class prng_key(extended):
>>> jnp.issubdtype(key.dtype, dtypes.prng_key)
True
"""
-pass
class ExtendedDType(metaclass=abc.ABCMeta):
@@ -139,12 +137,28 @@ _int4_dtypes = [
]
# Default types.
-bool_: type = np.bool_
+bool_ = np.bool_
-int_: type = np.int32 if config.jax_default_dtype_bits == '32' else np.int64
+int_: type[Any]
-uint: type = np.uint32 if config.jax_default_dtype_bits == '32' else np.uint64
+uint: type[Any]
-float_: type = np.float32 if config.jax_default_dtype_bits == '32' else np.float64
+float_: type[Any]
-complex_: type = np.complex64 if config.jax_default_dtype_bits == '32' else np.complex128
+complex_: type[Any]
-_default_types: dict[str, type] = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
+if config.default_dtype_bits.value == '32':
+int_ = np.int32
+uint = np.uint32
+float_ = np.float32
+complex_ = np.complex64
+else:
+int_ = np.int64
+uint = np.uint64
+float_ = np.float64
+complex_ = np.complex128
+_default_types: dict[str, type[Any]] = {
+'b': bool_,
+'i': int_,
+'u': uint,
+'f': float_,
+'c': complex_,
+}
# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])
@@ -219,7 +233,7 @@ def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False, allow_opa
"allow_opaque_dtype argument is deprecated; use allow_extended_dtype.",
DeprecationWarning)
allow_extended_dtype = allow_opaque_dtype
-return _canonicalize_dtype(config.x64_enabled, allow_extended_dtype, dtype) # type: ignore[bad-return-type]
+return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # type: ignore[bad-return-type]
# Default dtypes corresponding to Python scalars.
python_scalar_dtypes : dict[type, DType] = {
@@ -507,7 +521,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
if len(LUB) == 1:
return LUB.pop()
elif len(LUB) == 0:
-if config.jax_numpy_dtype_promotion == 'strict':
+if config.numpy_dtype_promotion.value == 'strict':
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting "
@@ -553,7 +567,7 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
# object identity, not object equality, due to the behavior of np.dtype.__eq__
a_tp = cast(JAXType, a if any(a is t for t in _weak_types) else np.dtype(a))
b_tp = cast(JAXType, b if any(b is t for t in _weak_types) else np.dtype(b))
-return np.dtype(_least_upper_bound(config.jax_numpy_dtype_promotion, a_tp, b_tp))
+return np.dtype(_least_upper_bound(config.numpy_dtype_promotion.value, a_tp, b_tp))
def is_weakly_typed(x: Any) -> bool:
try:
@@ -604,17 +618,17 @@ def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
# Trivial promotion case. This allows extended dtypes through.
out_dtype = dtypes[0]
out_weak_type = False
-elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
+elif all(weak_types) and config.numpy_dtype_promotion.value != 'strict':
# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the
# incorrect result with non-canonical weak types (e.g. weak int16).
# TODO(jakevdp): explore removing this special case.
-result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
+result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
*{_jax_type(dtype, False) for dtype in dtypes})
out_dtype = dtype(result_type)
out_weak_type = True
else:
-result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
+result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
out_dtype = dtype(result_type)
out_weak_type = any(result_type is t for t in _weak_types)

View File

@@ -29,6 +29,7 @@ from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
import warnings
from jax._src import ad_util
+from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects as effects_lib
@@ -38,7 +39,6 @@ from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
-from jax._src.config import config
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
@@ -316,9 +316,8 @@ register_constant_handler(core.Token, _token_constant_handler)
def get_canonical_source_file(frame: source_info_util.Frame) -> str:
source_file = frame.file_name
-if config.jax_hlo_source_file_canonicalization_regex:
+if pattern := config.hlo_source_file_canonicalization_regex.value:
-source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
+source_file = re.sub(pattern, '', source_file)
-'', source_file)
return source_file
def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
@@ -349,7 +348,7 @@ def _source_info_to_location(
name_stack: source_info_util.NameStack) -> ir.Location:
eqn_str = (f'{str(source_info.name_stack)}/'
f'{core.str_eqn_compact(primitive.name, params)}')
-if config.jax_include_full_tracebacks_in_locations:
+if config.include_full_tracebacks_in_locations.value:
if source_info.traceback is None:
loc = ir.Location.unknown()
else:
@@ -622,7 +621,7 @@ def sharded_aval(aval: core.AbstractValue,
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> tuple[int | Value, ...]:
-if config.jax_dynamic_shapes:
+if config.dynamic_shapes.value:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
ctx = ctx.replace(
@@ -774,7 +773,7 @@ def lower_jaxpr_to_module(
host_callbacks: list[Any] = []
dim_vars: Sequence[str]
-if not config.jax_dynamic_shapes:
+if not config.dynamic_shapes.value:
# Find the dimension variables
all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
for d in aval.shape if not core.is_constant_dim(d)]
@@ -1418,7 +1417,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, dim_var_values=dim_var_values)
-if config.jax_dynamic_shapes:
+if config.dynamic_shapes.value:
axis_size_env = {d: read(d)[0]
for a in avals_in if type(a) is core.DShapedArray
for d in a.shape if type(d) is core.Var}
@@ -1589,7 +1588,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
-if config.jax_dynamic_shapes:
+if config.dynamic_shapes.value:
# We might be applying this function to arguments with dynamic shapes,
# i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
# case, we need to form a jaxpr with leading binders for those axis size

View File

@@ -26,13 +26,13 @@ from weakref import ref
import numpy as np
-from jax._src import linear_util as lu
-from jax._src.config import config
from jax._src import ad_util
from jax._src import api_util
+from jax._src import config
from jax._src import core
-from jax._src import effects
from jax._src import dtypes
+from jax._src import effects
+from jax._src import linear_util as lu
from jax._src import profiler
from jax._src import source_info_util
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
@@ -106,7 +106,7 @@ class PartialVal(tuple):
"""
def __new__(cls, xs: tuple[AbstractValue | None, core.Value]):
pv, const = xs
-if config.jax_enable_checks:
+if config.enable_checks.value:
# type checks
assert isinstance(pv, (AbstractValue, type(None))), xs
assert (const is None or isinstance(const, core.Tracer) or
@@ -255,7 +255,7 @@ class JaxprTrace(Trace['JaxprTracer']):
# which were unknown to the first call (corresponding to in_avals).
# Wrap f to perform the partial evaluation and plumb out aux data.
-if not config.jax_dynamic_shapes:
+if not config.dynamic_shapes.value:
f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False)
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns),
tuple(in_avals))
@@ -275,7 +275,7 @@ class JaxprTrace(Trace['JaxprTracer']):
out_consts, non_fwd_res_ = split_list(out, [sum(out_knowns)])
# Form the complete list of residuals by forwarding some inputs.
-if config.jax_dynamic_shapes:
+if config.dynamic_shapes.value:
# With dynamic shapes, we may need to forward implicit arguments.
in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns)
in_consts_full = [None] * len(f.in_type)
@@ -303,7 +303,7 @@ class JaxprTrace(Trace['JaxprTracer']):
staged_params = update_params(staged_params, map(op.not_, in_knowns),
num_new_args)
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
-if config.jax_dynamic_shapes:
+if config.dynamic_shapes.value:
# With dynamic shapes, we may need to substitute Tracers into avals.
out_tracers = []
for aval, _ in out_type:
@@ -958,21 +958,21 @@ def tracers_to_jaxpr(
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type]
outvars, eqns, jaxpr_effects)
-config.jax_enable_checks and core.check_jaxpr(jaxpr)
+config.enable_checks.value and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
-config.jax_enable_checks and core.check_jaxpr(jaxpr)
+config.enable_checks.value and core.check_jaxpr(jaxpr)
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects, debug_info=dbg)
-config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
+config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@weakref_lru_cache
@@ -980,24 +980,24 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
"""Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
raise NotImplementedError
-config.jax_enable_checks and core.check_jaxpr(jaxpr)
+config.enable_checks.value and core.check_jaxpr(jaxpr)
constvars, invars = split_list(jaxpr.invars, [n])
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names[n:])
lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
debug_info=dbg)
-config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
+config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
raise NotImplementedError
-config.jax_enable_checks and core.check_jaxpr(jaxpr)
+config.enable_checks.value and core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
-config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
+config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr
@@ -1090,7 +1090,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
# check jaxpr_known and jaxpr_unknown in isolation
# TODO(mattjj): enable weak type checking here
-if config.jax_enable_checks:
+if config.enable_checks.value:
core.check_jaxpr(jaxpr_known)
core.check_jaxpr(jaxpr_unknown)
# check jaxpr_known has input type corresponding to known inputs of jaxpr
@@ -1256,7 +1256,7 @@ def _partial_eval_jaxpr_custom_cached(
known_outvars, known_eqns)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
known_eqns, known_effects)
-config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
+config.enable_checks.value and core.check_jaxpr(jaxpr_known)
_, ins_staged = partition_list(in_inst, jaxpr.invars)
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
@@ -1265,7 +1265,7 @@ def _partial_eval_jaxpr_custom_cached(
outs_staged, staged_eqns)
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns, staged_effects)
-config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
+config.enable_checks.value and core.check_jaxpr(jaxpr_staged)
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
len(non_input_res_refs))
@@ -1486,7 +1486,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
-config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
+config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr, used_inputs
@@ -1702,7 +1702,7 @@ class JaxprStackFrame:
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
jaxpr, out_type = _add_implicit_outputs(jaxpr)
-config.jax_enable_checks and core.check_jaxpr(jaxpr)
+config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, out_type, constvals
def newvar(self, aval):
@@ -2226,7 +2226,7 @@ def trace_to_subjaxpr_dynamic(
out_tracers = map(trace.full_raise, ans)
jaxpr, consts = frame.to_jaxpr(out_tracers)
del fun, main, trace, frame, in_tracers, out_tracers, ans
-config.jax_enable_checks and core.check_jaxpr(jaxpr)
+config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts
@@ -2432,7 +2432,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
jaxpr.effects, jaxpr.debug_info)
-config.jax_enable_checks and core.check_jaxpr(jaxpr)
+config.enable_checks.value and core.check_jaxpr(jaxpr)
return new_jaxpr, out_type

View File

@@ -34,8 +34,9 @@ import jax
from jax.errors import JAXTypeError
from jax._src import api_util
-from jax._src import core
from jax._src import compiler
+from jax._src import config
+from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
@@ -51,7 +52,6 @@ from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
-from jax._src.config import config
from jax._src.core import DShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
@@ -279,7 +279,7 @@ def xla_pmap_impl_lazy(
donated_invars: Sequence[bool],
is_explicit_global_axis_size: bool,
) -> Callable:
-if (config.jax_disable_jit and config.jax_eager_pmap and
+if (config.disable_jit.value and config.eager_pmap.value and
not is_explicit_global_axis_size and not any(d for d in donated_invars)):
def _emap_apply_fn(*args):
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
@@ -296,7 +296,7 @@ def xla_pmap_impl_lazy(
is_explicit_global_axis_size, *abstract_args)
# Don't re-abstractify args unless logging is enabled for performance.
-if config.jax_distributed_debug:
+if config.distributed_debug.value:
distributed_debug_log(("Running pmapped function", name),
("python function", fun.f),
("devices", devices),
@@ -433,7 +433,7 @@ class MapTrace(core.Trace):
def process_map(self, map_primitive, fun, tracers, params):
if params['devices'] is not None:
raise ValueError("Nested pmap with explicit devices argument.")
-if not config.jax_disable_jit:
+if not config.disable_jit.value:
bind = HashableFunction(
lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
(map_primitive, fun))
@@ -728,7 +728,7 @@ def lower_parallel_callable(
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
f"{replicas.jaxpr_replicas}")
-log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
+log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
@@ -1630,7 +1630,7 @@ ShardingInfo = tuple[
def _get_default_device() -> xc.Device:
-return config.jax_default_device or xb.local_devices()[0]
+return config.default_device.value or xb.local_devices()[0]
class _thread_local_decorator(threading.local):
@@ -1798,7 +1798,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
# code without tuple conversion.
device_assignment = tuple(da_object)
-log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
+log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s for with global shapes and types %s. "
@@ -2029,7 +2029,7 @@ def lower_sharding_computation(
transfer_mem_kind_in_jaxpr))
if not da_object.is_fully_addressable: # type: ignore
-if inline and config.jax_spmd_mode != 'allow_all':
+if inline and config.spmd_mode.value != 'allow_all':
raise RuntimeError( raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this " "Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and " "process (i.e. `Array`s with data sharded across multiple devices and "
@ -2116,7 +2116,7 @@ def lower_mesh_computation(
global_axis_sizes = mesh.shape global_axis_sizes = mesh.shape
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
if logger.isEnabledFor(log_priority): if logger.isEnabledFor(log_priority):
logger.log(log_priority, logger.log(log_priority,
"Compiling %s for %s mesh with global shapes and types %s. " "Compiling %s for %s mesh with global shapes and types %s. "

View File

@ -36,6 +36,7 @@ from jax._src import ad_util
from jax._src import api from jax._src import api
from jax._src import api_util from jax._src import api_util
from jax._src import array from jax._src import array
from jax._src import config
from jax._src import core from jax._src import core
from jax._src import dispatch from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
@ -45,7 +46,6 @@ from jax._src import pretty_printer as pp
from jax._src import source_info_util from jax._src import source_info_util
from jax._src import util from jax._src import util
from jax._src.abstract_arrays import array_types from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape) raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.interpreters import ad from jax._src.interpreters import ad
@ -93,7 +93,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
raise TypeError(msg) raise TypeError(msg)
assert shapes assert shapes
if config.jax_dynamic_shapes: if config.dynamic_shapes.value:
# pass dynamic shapes through unchecked # pass dynamic shapes through unchecked
return return
else: else:
@ -182,7 +182,7 @@ def _extract_tracers_dyn_shape(
shape: Sequence[Union[int, core.Tracer]] shape: Sequence[Union[int, core.Tracer]]
) -> tuple[list[core.Tracer], list[Optional[int]]]: ) -> tuple[list[core.Tracer], list[Optional[int]]]:
# Given a sequence representing a shape, pull out Tracers, replacing with None # Given a sequence representing a shape, pull out Tracers, replacing with None
if config.jax_dynamic_shapes: if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors # We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information). # raised are different (and have worse source provenance information).
dyn_shape = [d for d in shape if isinstance(d, core.Tracer)] dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
@ -794,7 +794,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
""" """
if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array): if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array):
return type_cast(Array, operand) return type_cast(Array, operand)
if config.jax_dynamic_shapes: if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors # We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information). # raised are different (and have worse source provenance information).
dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) dyn_shape, static_shape = _extract_tracers_dyn_shape(shape)
@ -1649,7 +1649,7 @@ def _unbroadcast(aval, x):
return _reduce_sum(x, list(range(len(x_shape)))) return _reduce_sum(x, list(range(len(x_shape))))
else: else:
dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)] dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)]
if config.jax_enable_checks: assert all(aval.shape[i] == 1 for i in dims) if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims)
return reshape(_reduce_sum(x, dims), aval.shape) return reshape(_reduce_sum(x, dims), aval.shape)
def _maybe_broadcast(target_shape, x): def _maybe_broadcast(target_shape, x):
@ -3358,7 +3358,7 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions):
msg = 'reshape new_sizes must all be positive, got {}.' msg = 'reshape new_sizes must all be positive, got {}.'
raise TypeError(msg.format(new_sizes)) raise TypeError(msg.format(new_sizes))
# TODO(necula): re-enable this check # TODO(necula): re-enable this check
if (not config.jax_dynamic_shapes and if (not config.dynamic_shapes.value and
not math.prod(np.shape(operand)) == math.prod(new_sizes)): not math.prod(np.shape(operand)) == math.prod(new_sizes)):
msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.' msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
raise TypeError(msg.format(new_sizes, np.shape(operand))) raise TypeError(msg.format(new_sizes, np.shape(operand)))
@ -4850,7 +4850,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
# bool(obj) for an ndarray raises an error, so we check len # bool(obj) for an ndarray raises an error, so we check len
if not len(obj): # pylint: disable=g-explicit-length-test if not len(obj): # pylint: disable=g-explicit-length-test
return return
if (config.jax_dynamic_shapes and isinstance(obj, (tuple, list)) and if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and
any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)): any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)):
return # TODO(mattjj): handle more checks in the dynamic shape case return # TODO(mattjj): handle more checks in the dynamic shape case
obj_arr = np.array(obj) obj_arr = np.array(obj)
@ -4913,17 +4913,17 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[Precision
value to apply to both operands, or as a sequence of two values. value to apply to both operands, or as a sequence of two values.
""" """
if precision is None: if precision is None:
if config.jax_default_matmul_precision is None: if config.default_matmul_precision.value is None:
return None return None
try: try:
return type_cast( return type_cast(
tuple[PrecisionType, PrecisionType], tuple[PrecisionType, PrecisionType],
(Precision(config.jax_default_matmul_precision), (Precision(config.default_matmul_precision.value),
Precision(config.jax_default_matmul_precision))) Precision(config.default_matmul_precision.value)))
except TypeError: except TypeError:
raise ValueError( raise ValueError(
"jax_default_matmul_precision flag must be set to None or a value in " "jax_default_matmul_precision flag must be set to None or a value in "
f"{list(Precision._strings)}, but got {config.jax_default_matmul_precision}" f"{list(Precision._strings)}, but got {config.default_matmul_precision.value}"
) from None ) from None
elif isinstance(precision, str) and precision in Precision._strings: elif isinstance(precision, str) and precision in Precision._strings:
return type_cast(tuple[PrecisionType, PrecisionType], return type_cast(tuple[PrecisionType, PrecisionType],

View File

@ -68,12 +68,13 @@ import operator
from typing import Any, Callable, NamedTuple from typing import Any, Callable, NamedTuple
import weakref import weakref
from jax._src.tree_util import tree_map from jax._src import config
from jax._src.config import config
from jax._src import core from jax._src import core
from jax._src import traceback_util from jax._src import traceback_util
from jax._src.tree_util import tree_map
from jax._src.util import curry from jax._src.util import curry
traceback_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__)
@ -333,13 +334,13 @@ def cache(call: Callable):
def memoized_fun(fun: WrappedFun, *args): def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {}) cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks: if config.check_tracer_leaks.value:
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args, key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
config.x64_enabled, config.jax_default_device, config.enable_x64.value, config.default_device.value,
config._trace_context()) config.config._trace_context())
else: else:
key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled, key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
config.jax_default_device, config._trace_context()) config.default_device.value, config.config._trace_context())
result = cache.get(key, None) result = cache.get(key, None)
if result is not None: if result is not None:
ans, stores = result ans, stores = result
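
The linear_util.py change keeps the memoization key sensitive to configuration, but the values now come from typed state objects; the trace context is still read off the underlying Config instance, which the module exposes as `config.config`. A rough sketch, assuming only the names shown in the hunk (the helper function is illustrative, and `_trace_context` is an internal API):

from jax._src import config

def _config_sensitive_key(*key_parts):
    # Any change to x64 mode, the default device, or the ambient trace context
    # yields a different key, so stale cache entries are never reused.
    return (*key_parts,
            config.enable_x64.value,
            config.default_device.value,
            config.config._trace_context())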

View File

@ -20,10 +20,10 @@ from typing import Any, Callable, NamedTuple, Optional, TypeVar
import warnings import warnings
from jax._src import dtypes
from jax._src import api from jax._src import api
from jax._src import config
from jax._src import core from jax._src import core
from jax._src.config import config from jax._src import dtypes
from jax._src.lax import lax from jax._src.lax import lax
from jax._src.util import safe_zip, safe_map from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
@ -161,7 +161,7 @@ def _wraps(
try: try:
mod = module or fun.__module__ mod = module or fun.__module__
except AttributeError: except AttributeError:
if config.jax_enable_checks: if config.enable_checks.value:
raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.") raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.")
else: else:
name = f"{mod}.{name}" name = f"{mod}.{name}"
@ -206,7 +206,7 @@ def _wraps(
if kept_sections: if kept_sections:
docstr += "\n" + "\n\n".join(kept_sections) + "\n" docstr += "\n" + "\n\n".join(kept_sections) + "\n"
except: except:
if config.jax_enable_checks: if config.enable_checks.value:
raise raise
docstr = fun.__doc__ docstr = fun.__doc__
@ -229,7 +229,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
return [lax.asarray(arg) for arg in args] return [lax.asarray(arg) for arg in args]
else: else:
shapes = [np.shape(arg) for arg in args] shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes: if config.dynamic_shapes.value:
# With dynamic shapes we don't support singleton-dimension broadcasting; # With dynamic shapes we don't support singleton-dimension broadcasting;
# we instead broadcast out to the full shape as a temporary workaround. # we instead broadcast out to the full shape as a temporary workaround.
# TODO(mattjj): revise this workaround # TODO(mattjj): revise this workaround
@ -242,7 +242,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
if len(nonscalar_ranks) < 2: if len(nonscalar_ranks) < 2:
return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
else: else:
if config.jax_numpy_rank_promotion != "allow": if config.numpy_rank_promotion.value != "allow":
_rank_promotion_warning_or_error(fun_name, shapes) _rank_promotion_warning_or_error(fun_name, shapes)
result_rank = len(lax.broadcast_shapes(*shapes)) result_rank = len(lax.broadcast_shapes(*shapes))
return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp) return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
@ -250,13 +250,13 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]): def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
if config.jax_numpy_rank_promotion == "warn": if config.numpy_rank_promotion.value == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. " msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to " "Set the jax_numpy_rank_promotion config option to 'allow' to "
"disable this warning; for more information, see " "disable this warning; for more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
elif config.jax_numpy_rank_promotion == "raise": elif config.numpy_rank_promotion.value == "raise":
msg = ("Operands could not be broadcast together for {} on shapes {} " msg = ("Operands could not be broadcast together for {} on shapes {} "
"and with the config option jax_numpy_rank_promotion='raise'. " "and with the config option jax_numpy_rank_promotion='raise'. "
"For more information, see " "For more information, see "

View File

@ -25,6 +25,7 @@ import warnings
import numpy as np import numpy as np
from jax._src import config
from jax._src import core from jax._src import core
from jax._src import stages from jax._src import stages
from jax._src import dispatch from jax._src import dispatch
@ -46,7 +47,6 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax._src.config import config
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
@ -182,7 +182,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
@wraps(fun) @wraps(fun)
@api_boundary @api_boundary
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
if config.jax_disable_jit: if config.disable_jit.value:
return fun(*args, **kwargs) return fun(*args, **kwargs)
return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0] return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
@ -275,7 +275,7 @@ def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, donate_argnames, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, static_argnums, static_argnames, device,
backend, abstracted_axes): backend, abstracted_axes):
if abstracted_axes and not config.jax_dynamic_shapes: if abstracted_axes and not config.dynamic_shapes.value:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes") raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
check_callable(fun) check_callable(fun)
@ -432,7 +432,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
dyn_kwargs = {} dyn_kwargs = {}
del kwargs del kwargs
if (donate_argnums or donate_argnames) and not config.jax_debug_nans: if (donate_argnums or donate_argnames) and not config.debug_nans.value:
donated_invars = donation_vector( donated_invars = donation_vector(
donate_argnums, donate_argnames, dyn_args, dyn_kwargs) donate_argnums, donate_argnames, dyn_args, dyn_kwargs)
else: else:
@ -462,7 +462,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
assert in_shardings is not None or all(i is not None for i in in_shardings) assert in_shardings is not None or all(i is not None for i in in_shardings)
assert out_shardings is not None or all(o is not None for o in out_shardings) assert out_shardings is not None or all(o is not None for o in out_shardings)
if config.jax_dynamic_shapes: if config.dynamic_shapes.value:
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e) in_avals = tuple(a for a, e in in_type if e)
else: else:
@ -490,7 +490,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
assert len(explicit_args) == len(canonicalized_in_shardings_flat) assert len(explicit_args) == len(canonicalized_in_shardings_flat)
if config.jax_dynamic_shapes: if config.dynamic_shapes.value:
implicit_args = _extract_implicit_args(in_type, explicit_args) implicit_args = _extract_implicit_args(in_type, explicit_args)
else: else:
implicit_args = [] implicit_args = []
@ -892,7 +892,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
"pjit in_shardings", in_tree, orig_in_shardings, "pjit in_shardings", in_tree, orig_in_shardings,
tupled_args=True) tupled_args=True)
if not config.jax_dynamic_shapes: if not config.dynamic_shapes.value:
pjit_check_aval_sharding(in_shardings_flat, in_avals, pjit_check_aval_sharding(in_shardings_flat, in_avals,
None if debug_info is None else debug_info.arg_names, None if debug_info is None else debug_info.arg_names,
"pjit arguments", allow_uneven_sharding=False) "pjit arguments", allow_uneven_sharding=False)
@ -909,14 +909,14 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec", "Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for) pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
if config.jax_dynamic_shapes: if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, in_type), debug_info=pe_debug) lu.annotate(fun, in_type), debug_info=pe_debug)
else: else:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
fun, in_type, debug_info=pe_debug) fun, in_type, debug_info=pe_debug)
if not config.jax_dynamic_shapes: if not config.dynamic_shapes.value:
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths()) jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
if any(isinstance(c, core.Tracer) for c in consts): if any(isinstance(c, core.Tracer) for c in consts):
@ -944,7 +944,7 @@ def _check_and_canonicalize_out_shardings(
"pjit out_shardings", out_tree(), orig_out_shardings, "pjit out_shardings", out_tree(), orig_out_shardings,
tupled_args=False) tupled_args=False)
if not config.jax_dynamic_shapes: if not config.dynamic_shapes.value:
pjit_check_aval_sharding( pjit_check_aval_sharding(
out_shardings_flat, out_type, out_shardings_flat, out_type,
None if debug_info is None else debug_info.result_paths, None if debug_info is None else debug_info.result_paths,
@ -1132,10 +1132,10 @@ def _pjit_call_impl_python(
lowering_parameters=mlir.LoweringParameters()).compile() lowering_parameters=mlir.LoweringParameters()).compile()
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on. # This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.jax_enable_checks: if compiled._auto_spmd_lowering and config.enable_checks.value:
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings, pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
jaxpr.jaxpr.debug_info) jaxpr.jaxpr.debug_info)
if config.jax_distributed_debug: if config.distributed_debug.value:
# Defensively only perform fingerprint logic if debug logging is enabled # Defensively only perform fingerprint logic if debug logging is enabled
# NOTE(skyewm): I didn't benchmark this # NOTE(skyewm): I didn't benchmark this
fingerprint = None fingerprint = None
@ -1151,7 +1151,7 @@ def _pjit_call_impl_python(
try: try:
return compiled.unsafe_call(*args), compiled return compiled.unsafe_call(*args), compiled
except FloatingPointError: except FloatingPointError:
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return _ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
@ -1159,7 +1159,7 @@ def _pjit_call_impl_python(
# but not `fun.call_wrapped` on the same arguments. Let's tell the user. # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
msg = ("An invalid value was encountered in the output of the " msg = ("An invalid value was encountered in the output of the "
f"`jit`-decorated function {name}. Because " f"`jit`-decorated function {name}. Because "
"config.jax_debug_nans and/or config.jax_debug_infs is set, the " "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` " "de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more " "decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not " "precise error message. However, the de-optimized function did not "
@ -1313,7 +1313,7 @@ def pjit_staging_rule(trace, *args, **params):
jaxpr = params['jaxpr'] jaxpr = params['jaxpr']
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
propagate_source_info=False) propagate_source_info=False)
elif config.jax_dynamic_shapes: elif config.dynamic_shapes.value:
source_info = source_info_util.current() source_info = source_info_util.current()
out_tracers = [] out_tracers = []
for aval in _out_type(params['jaxpr']): for aval in _out_type(params['jaxpr']):

View File

@ -31,7 +31,7 @@ from jax import tree_util
from jax._src import ad_util from jax._src import ad_util
from jax._src import api from jax._src import api
from jax._src import basearray from jax._src import basearray
from jax._src import config as config_lib from jax._src import config as config
from jax._src import core from jax._src import core
from jax._src import dispatch from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
@ -40,7 +40,6 @@ from jax._src import sharding_specs
from jax._src import tree_util as tree_util_internal from jax._src import tree_util as tree_util_internal
from jax._src import typing from jax._src import typing
from jax._src.api import jit, vmap from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.dtypes import float0 from jax._src.dtypes import float0
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
@ -969,7 +968,7 @@ def _threefry_seed(seed: typing.Array) -> typing.Array:
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1]) convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
k1 = convert( k1 = convert(
lax.shift_right_logical(seed, lax_internal._const(seed, 32))) lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
with config_lib.numpy_dtype_promotion('standard'): with config.numpy_dtype_promotion('standard'):
# TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
# inputs. We should avoid this. # inputs. We should avoid this.
k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF))) k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
@ -1270,7 +1269,7 @@ def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:
@partial(jit, static_argnums=(1,)) @partial(jit, static_argnums=(1,))
def _threefry_split(key, shape) -> typing.Array: def _threefry_split(key, shape) -> typing.Array:
if config.jax_threefry_partitionable: if config.threefry_partitionable.value:
return _threefry_split_foldlike(key, shape) # type: ignore return _threefry_split_foldlike(key, shape) # type: ignore
else: else:
return _threefry_split_original(key, shape) # type: ignore return _threefry_split_original(key, shape) # type: ignore
@ -1305,7 +1304,7 @@ def threefry_random_bits(key: typing.Array, bit_width, shape):
if bit_width not in (8, 16, 32, 64): if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.") raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
if config.jax_threefry_partitionable: if config.threefry_partitionable.value:
return _threefry_random_bits_partitionable(key, bit_width, shape) return _threefry_random_bits_partitionable(key, bit_width, shape)
else: else:
return _threefry_random_bits_original(key, bit_width, shape) return _threefry_random_bits_original(key, bit_width, shape)
@ -1398,7 +1397,7 @@ def _rbg_seed(seed: typing.Array) -> typing.Array:
return jnp.concatenate([halfkey, halfkey]) return jnp.concatenate([halfkey, halfkey])
def _rbg_split(key: typing.Array, shape: Shape) -> typing.Array: def _rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
if config.jax_threefry_partitionable: if config.threefry_partitionable.value:
_threefry_split = _threefry_split_foldlike _threefry_split = _threefry_split_foldlike
else: else:
_threefry_split = _threefry_split_original _threefry_split = _threefry_split_original
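
In prng.py the module previously needed two imports, `config as config_lib` for the context managers and the `config` object for reads; after the migration a single `from jax._src import config` serves both. A sketch (the function is illustrative, the two config usages come from the hunks):

from jax._src import config

def needs_both() -> None:
    if config.threefry_partitionable.value:         # typed state read
        pass
    with config.numpy_dtype_promotion('standard'):  # context manager, as before
        pass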

View File

@ -16,8 +16,8 @@ from functools import partial
import operator import operator
from jax._src import api from jax._src import api
from jax._src import config
from jax._src import dtypes as _dtypes from jax._src import dtypes as _dtypes
from jax._src.config import config
from jax._src.tree_util import tree_map, tree_reduce from jax._src.tree_util import tree_map, tree_reduce
import numpy as np import numpy as np
@ -130,7 +130,7 @@ def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
def _check_dtypes_match(xs, ys): def _check_dtypes_match(xs, ys):
def _assert_dtypes_match(x, y): def _assert_dtypes_match(x, y):
if config.x64_enabled: if config.enable_x64.value:
assert _dtype(x) == _dtype(y) assert _dtype(x) == _dtype(y)
else: else:
assert (_dtypes.canonicalize_dtype(_dtype(x)) == assert (_dtypes.canonicalize_dtype(_dtype(x)) ==

View File

@ -26,13 +26,12 @@ import jax.numpy as jnp
from jax import lax from jax import lax
from jax.numpy.linalg import cholesky, svd, eigh from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import config as config_lib from jax._src import config
from jax._src import core from jax._src import core
from jax._src import dtypes from jax._src import dtypes
from jax._src import prng from jax._src import prng
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.api import jit, vmap from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.core import NamedShape from jax._src.core import NamedShape
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
@ -77,17 +76,17 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
elif _arraylike(key): elif _arraylike(key):
# Call random_wrap here to surface errors for invalid keys. # Call random_wrap here to surface errors for invalid keys.
wrapped_key = prng.random_wrap(key, impl=default_prng_impl()) wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
if config.jax_legacy_prng_key == 'error': if config.legacy_prng_key.value == 'error':
raise ValueError( raise ValueError(
'Legacy uint32 key array passed as key to jax.random function. ' 'Legacy uint32 key array passed as key to jax.random function. '
'Please create keys using jax.random.key(). If use of a raw key array ' 'Please create keys using jax.random.key(). If use of a raw key array '
'was intended, set jax_legacy_prng_key="allow".') 'was intended, set jax_legacy_prng_key="allow".')
elif config.jax_legacy_prng_key == 'warn': elif config.legacy_prng_key.value == 'warn':
warnings.warn( warnings.warn(
'Legacy uint32 key array passed as key to jax.random function. ' 'Legacy uint32 key array passed as key to jax.random function. '
'Please create keys using jax.random.key(). If use of a raw key array ' 'Please create keys using jax.random.key(). If use of a raw key array '
'was intended, set jax_legacy_prng_key="allow".', stacklevel=2) 'was intended, set jax_legacy_prng_key="allow".', stacklevel=2)
elif config.jax_enable_custom_prng: elif config.enable_custom_prng.value:
# TODO(jakevdp): possibly remove this warning condition. # TODO(jakevdp): possibly remove this warning condition.
warnings.warn( warnings.warn(
'Raw arrays as random keys to jax.random functions are deprecated. ' 'Raw arrays as random keys to jax.random functions are deprecated. '
@ -101,7 +100,7 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
def _return_prng_keys(was_wrapped, key): def _return_prng_keys(was_wrapped, key):
# TODO(frostig): remove once we always enable_custom_prng # TODO(frostig): remove once we always enable_custom_prng
assert jnp.issubdtype(key.dtype, dtypes.prng_key) assert jnp.issubdtype(key.dtype, dtypes.prng_key)
if config.jax_enable_custom_prng: if config.enable_custom_prng.value:
return key return key
else: else:
return prng.random_unwrap(key) if was_wrapped else key return prng.random_unwrap(key) if was_wrapped else key
@ -120,7 +119,7 @@ def default_prng_impl():
The default implementation is determined by ``config.jax_default_prng_impl``, The default implementation is determined by ``config.jax_default_prng_impl``,
which specifies it by name. which specifies it by name.
""" """
impl_name = config.jax_default_prng_impl impl_name = config.default_prng_impl.value
assert impl_name in prng.prngs, impl_name assert impl_name in prng.prngs, impl_name
return prng.prngs[impl_name] return prng.prngs[impl_name]
@ -225,8 +224,8 @@ def PRNGKey(seed: Union[int, Array], *,
# TODO(frostig): remove once we always enable_custom_prng # TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name): def _check_default_impl_with_no_custom_prng(impl, name):
default_impl = default_prng_impl() default_impl = default_prng_impl()
default_name = config.jax_default_prng_impl default_name = config.default_prng_impl.value
if not config.jax_enable_custom_prng and default_impl is not impl: if not config.enable_custom_prng.value and default_impl is not impl:
raise RuntimeError('jax_enable_custom_prng must be enabled in order ' raise RuntimeError('jax_enable_custom_prng must be enabled in order '
f'to seed an RNG with an implementation "{name}" ' f'to seed an RNG with an implementation "{name}" '
f'differing from the default "{default_name}".') f'differing from the default "{default_name}".')
@ -829,7 +828,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
else: # 'cholesky' else: # 'cholesky'
factor = cholesky(cov) factor = cholesky(cov)
normal_samples = normal(key, shape + mean.shape[-1:], dtype) normal_samples = normal(key, shape + mean.shape[-1:], dtype)
with config_lib.numpy_rank_promotion('allow'): with config.numpy_rank_promotion('allow'):
result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
return result return result

View File

@ -24,11 +24,11 @@ import numpy as np
from jax._src import api_util from jax._src import api_util
from jax._src import ad_util from jax._src import ad_util
from jax._src import config
from jax._src import core from jax._src import core
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src import source_info_util from jax._src import source_info_util
from jax._src import tree_util from jax._src import tree_util
from jax._src.config import config
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
@ -262,7 +262,7 @@ run_state_p.multiple_results = True
def _run_state_bind(*args: Any, jaxpr: core.Jaxpr, def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...]): which_linear: tuple[bool, ...]):
if config.jax_enable_checks: if config.enable_checks.value:
core.check_jaxpr(jaxpr) core.check_jaxpr(jaxpr)
assert len(jaxpr.invars) == len(args) assert len(jaxpr.invars) == len(args)
assert len(which_linear) == len(args) assert len(which_linear) == len(args)

View File

@ -27,7 +27,7 @@ from typing import Any, Callable
import jax import jax
from jax import core from jax import core
from jax._src.config import config from jax._src import config
from jax._src.lib import tpu_mosaic from jax._src.lib import tpu_mosaic
from jax._src.lib import xla_client from jax._src.lib import xla_client
from jax.interpreters import mlir from jax.interpreters import mlir
@ -38,7 +38,7 @@ from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager from jaxlib.mlir.passmanager import PassManager
import numpy as np import numpy as np
config.define_bool_state( mosaic_use_cpp_passes = config.config.define_bool_state(
name="mosaic_use_cpp_passes", name="mosaic_use_cpp_passes",
default=False, default=False,
help=( help=(
@ -54,13 +54,13 @@ tpu = tpu_mosaic.tpu
apply_vector_layout = tpu_mosaic.apply_vector_layout apply_vector_layout = tpu_mosaic.apply_vector_layout
infer_memref_layout = tpu_mosaic.infer_memref_layout infer_memref_layout = tpu_mosaic.infer_memref_layout
config.define_bool_state( mosaic_allow_hlo = config.config.define_bool_state(
name="jax_mosaic_allow_hlo", name="jax_mosaic_allow_hlo",
default=False, default=False,
help="Allow hlo dialects in Mosaic", help="Allow hlo dialects in Mosaic",
) )
config.define_bool_state( mosaic_dump_mlir = config.config.define_bool_state(
name="jax_mosaic_dump_mlir", name="jax_mosaic_dump_mlir",
default=False, default=False,
help="Print mlir module after each pass", help="Print mlir module after each pass",
@ -243,7 +243,7 @@ def _lower_tpu_kernel(
) )
dump_mlir(module, "initial module") dump_mlir(module, "initial module")
if config.jax_mosaic_allow_hlo: if mosaic_allow_hlo.value:
# Run hlo dialect conversion: hlo -> linalg -> vector. # Run hlo dialect conversion: hlo -> linalg -> vector.
pipeline = [ pipeline = [
"hlo-legalize-to-arithmetic", "hlo-legalize-to-arithmetic",
@ -255,7 +255,7 @@ def _lower_tpu_kernel(
) )
dump_mlir(module, "after hlo conversion module") dump_mlir(module, "after hlo conversion module")
if config.mosaic_use_cpp_passes: if mosaic_use_cpp_passes.value:
pipeline = [ pipeline = [
( (
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})" f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
@ -278,7 +278,7 @@ def _lower_tpu_kernel(
module.operation.verify() module.operation.verify()
dump_mlir(module, "after infer vector layout pass") dump_mlir(module, "after infer vector layout pass")
if config.mosaic_use_cpp_passes: if mosaic_use_cpp_passes.value:
pipeline = [ pipeline = [
( (
"func.func(tpu-apply-vector-layout{sublane-count=8" "func.func(tpu-apply-vector-layout{sublane-count=8"
@ -418,6 +418,6 @@ def _lowered_as_tpu_kernel(
def dump_mlir(module: ir.Module, msg: str): def dump_mlir(module: ir.Module, msg: str):
"""A helper function to print mlir module with a message.""" """A helper function to print mlir module with a message."""
if config.jax_mosaic_dump_mlir: if mosaic_dump_mlir.value:
print(f"[jax_mosaic_dump_mlir] {msg}") print(f"[jax_mosaic_dump_mlir] {msg}")
print(module) print(module)
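
The Mosaic hunks also show how new options are defined under the migrated import: `define_bool_state` is called on the underlying Config instance (`config.config`) and the returned holder is read via `.value`, rather than looking the flag up as an attribute of the global config. A sketch with a hypothetical flag name standing in for the real ones:

from jax._src import config

# Hypothetical flag, shown only to illustrate the holder-based pattern above.
_example_dump_ir = config.config.define_bool_state(
    name="example_mosaic_dump_ir",
    default=False,
    help="Illustrative only: print IR after each pass.",
)

def dump_ir(module: object) -> None:
    if _example_dump_ir.value:   # not `config.example_mosaic_dump_ir`
        print(module)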

View File

@ -19,9 +19,9 @@ import traceback
import types import types
from typing import Any, Callable, Optional, TypeVar, cast from typing import Any, Callable, Optional, TypeVar, cast
from jax._src.config import config from jax._src import config
from jax._src.lib import xla_extension
from jax._src import util from jax._src import util
from jax._src.lib import xla_extension
C = TypeVar("C", bound=Callable[..., Any]) C = TypeVar("C", bound=Callable[..., Any])
@ -139,7 +139,7 @@ def _ipython_supports_tracebackhide() -> bool:
return IPython.version_info[:2] >= (7, 17) return IPython.version_info[:2] >= (7, 17)
def _filtering_mode() -> str: def _filtering_mode() -> str:
mode = config.jax_traceback_filtering mode = config.traceback_filtering.value
if mode is None or mode == "auto": if mode is None or mode == "auto":
if (_running_under_ipython() and _ipython_supports_tracebackhide()): if (_running_under_ipython() and _ipython_supports_tracebackhide()):
mode = "tracebackhide" mode = "tracebackhide"

View File

@ -23,9 +23,9 @@ import weakref
import numpy as np import numpy as np
from jax._src import config
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import utils as jaxlib_utils from jax._src.lib import utils as jaxlib_utils
from jax._src.config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -257,10 +257,10 @@ def cache(max_size=4096):
@functools.wraps(f) @functools.wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if config.jax_check_tracer_leaks: if config.check_tracer_leaks.value:
return f(*args, **kwargs) return f(*args, **kwargs)
else: else:
return cached(config._trace_context(), *args, **kwargs) return cached(config.config._trace_context(), *args, **kwargs)
wrapper.cache_clear = cached.cache_clear wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info wrapper.cache_info = cached.cache_info
@ -278,7 +278,7 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
behave similar to `functools.lru_cache`. behave similar to `functools.lru_cache`.
""" """
global _weakref_lru_caches global _weakref_lru_caches
cached_call = xc.weakref_lru_cache(config._trace_context, call, maxsize) cached_call = xc.weakref_lru_cache(config.config._trace_context, call, maxsize)
_weakref_lru_caches.add(cached_call) _weakref_lru_caches.add(cached_call)
return cached_call return cached_call
@ -464,7 +464,7 @@ def distributed_debug_log(*pairs):
pairs: A sequence of label/value pairs to log. The first pair is treated as pairs: A sequence of label/value pairs to log. The first pair is treated as
a heading for subsequent pairs. a heading for subsequent pairs.
""" """
if config.jax_distributed_debug: if config.distributed_debug.value:
lines = ["\nDISTRIBUTED_DEBUG_BEGIN"] lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
try: try:
lines.append(f"{pairs[0][0]}: {pairs[0][1]}") lines.append(f"{pairs[0][0]}: {pairs[0][1]}")

View File

@ -35,9 +35,8 @@ import threading
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import warnings import warnings
from jax._src import config
from jax._src import distributed from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import config
from jax._src.lib import cuda_versions from jax._src.lib import cuda_versions
from jax._src.lib import xla_client from jax._src.lib import xla_client
from jax._src.lib import xla_extension from jax._src.lib import xla_extension
@ -63,34 +62,34 @@ XlaBackend = xla_client.Client
# TODO(phawkins): Remove jax_xla_backend. # TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = jax_config.DEFINE_string( _XLA_BACKEND = config.DEFINE_string(
'jax_xla_backend', '', 'jax_xla_backend', '',
'Deprecated, please use --jax_platforms instead.') 'Deprecated, please use --jax_platforms instead.')
BACKEND_TARGET = jax_config.DEFINE_string( BACKEND_TARGET = config.DEFINE_string(
'jax_backend_target', 'jax_backend_target',
os.getenv('JAX_BACKEND_TARGET', '').lower(), os.getenv('JAX_BACKEND_TARGET', '').lower(),
'Either "local" or "rpc:address" to connect to a remote service target.') 'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit # TODO(skye): warn when this is used once we test out --jax_platforms a bit
_PLATFORM_NAME = jax_config.DEFINE_string( _PLATFORM_NAME = config.DEFINE_string(
'jax_platform_name', 'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(), os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Deprecated, please use --jax_platforms instead.') 'Deprecated, please use --jax_platforms instead.')
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string( CUDA_VISIBLE_DEVICES = config.DEFINE_string(
'jax_cuda_visible_devices', 'all', 'jax_cuda_visible_devices', 'all',
'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' 'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.') 'comma-separate list of integer device IDs.')
_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string( _ROCM_VISIBLE_DEVICES = config.DEFINE_string(
'jax_rocm_visible_devices', 'all', 'jax_rocm_visible_devices', 'all',
'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.') 'comma-separate list of integer device IDs.')
_USE_MOCK_GPU_CLIENT = jax_config.DEFINE_bool( _USE_MOCK_GPU_CLIENT = config.DEFINE_bool(
name="use_mock_gpu_client", name="use_mock_gpu_client",
default=False, default=False,
help="If True, use a mock GPU client instead of a real one.", help="If True, use a mock GPU client instead of a real one.",
) )
_MOCK_NUM_GPUS = jax_config.DEFINE_integer( _MOCK_NUM_GPUS = config.DEFINE_integer(
name="mock_num_gpus", name="mock_num_gpus",
default=1, default=1,
help="Mock GPU client number of gpus.", help="Mock GPU client number of gpus.",
@ -223,7 +222,7 @@ def _check_cuda_versions():
def make_gpu_client( def make_gpu_client(
*, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str] *, platform_name: str, visible_devices_flag: config.FlagHolder[str]
) -> xla_client.Client: ) -> xla_client.Client:
visible_devices = visible_devices_flag.value visible_devices = visible_devices_flag.value
allowed_devices = None allowed_devices = None
@ -564,11 +563,10 @@ def backends() -> dict[str, xla_client.Client]:
with _backend_lock: with _backend_lock:
if _backends: if _backends:
return _backends return _backends
if config.jax_platforms: if jax_platforms := config.jax_platforms.value:
jax_platforms = config.jax_platforms.split(",")
platforms = [] platforms = []
# Allow platform aliases in the list of platforms. # Allow platform aliases in the list of platforms.
for platform in jax_platforms: for platform in jax_platforms.split(","):
platforms.extend(expand_platform_alias(platform)) platforms.extend(expand_platform_alias(platform))
priorities = range(len(platforms), 0, -1) priorities = range(len(platforms), 0, -1)
# If the user specified a list of platforms explicitly, always fail # If the user specified a list of platforms explicitly, always fail
@ -597,14 +595,14 @@ def backends() -> dict[str, xla_client.Client]:
_backend_errors[platform] = str(err) _backend_errors[platform] = str(err)
logger.info(err_msg) logger.info(err_msg)
else: else:
if config.jax_platforms: if config.jax_platforms.value:
err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)" err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
else: else:
err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)" err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
raise RuntimeError(err_msg) raise RuntimeError(err_msg)
assert _default_backend is not None assert _default_backend is not None
if not config.jax_platforms: if not config.jax_platforms.value:
_suggest_missing_backends() _suggest_missing_backends()
return _backends return _backends
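
The xla_bridge.py hunks switch the flag definitions from the `jax_config` alias to the `config` module, but the `DEFINE_*` helpers and the `FlagHolder` type they return are unchanged; holders are still read via `.value`. A sketch, with a hypothetical flag standing in for the real ones:

from jax._src import config

# Hypothetical flag, mirroring the DEFINE_string calls in the diff above.
_EXAMPLE_TARGET = config.DEFINE_string(
    'example_backend_target', '',
    'Illustrative only: where a hypothetical backend should connect.')

def make_example_client(*, target_flag: config.FlagHolder[str]) -> str:
    return target_flag.value   # read at call time, typed as str

# e.g. make_example_client(target_flag=_EXAMPLE_TARGET)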

View File

@ -55,6 +55,7 @@ import numpy as np
import jax import jax
from jax import lax from jax import lax
from jax._src import config
from jax._src import core from jax._src import core
from jax._src.custom_derivatives import lift_jvp from jax._src.custom_derivatives import lift_jvp
from jax._src import linear_util as lu from jax._src import linear_util as lu
@ -67,7 +68,6 @@ from jax._src.lib import pytree
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map, tree_unflatten from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import safe_map, safe_zip, split_list from jax.util import safe_map, safe_zip, split_list
from jax._src.config import config
from jax._src.lax.control_flow import _check_tree_and_avals from jax._src.lax.control_flow import _check_tree_and_avals
from jax._src.numpy import lax_numpy from jax._src.numpy import lax_numpy
from jax.experimental import sparse from jax.experimental import sparse
@ -614,7 +614,7 @@ def _add_sparse(spenv, *spvalues):
raise NotImplementedError("Addition between sparse matrices of different shapes.") raise NotImplementedError("Addition between sparse matrices of different shapes.")
if X.indices_ref == Y.indices_ref: if X.indices_ref == Y.indices_ref:
out_data = lax.add(spenv.data(X), spenv.data(Y)) out_data = lax.add(spenv.data(X), spenv.data(Y))
if config.jax_enable_checks: if config.enable_checks.value:
assert X.indices_sorted == Y.indices_sorted assert X.indices_sorted == Y.indices_sorted
assert X.unique_indices == Y.unique_indices assert X.unique_indices == Y.unique_indices
out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref, out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref,
@ -657,7 +657,7 @@ def _mul_sparse(spenv, *spvalues):
X, Y = spvalues X, Y = spvalues
if X.is_sparse() and Y.is_sparse(): if X.is_sparse() and Y.is_sparse():
if X.indices_ref == Y.indices_ref and X.unique_indices: if X.indices_ref == Y.indices_ref and X.unique_indices:
if config.jax_enable_checks: if config.enable_checks.value:
assert X.indices_sorted == Y.indices_sorted assert X.indices_sorted == Y.indices_sorted
assert X.unique_indices == Y.unique_indices assert X.unique_indices == Y.unique_indices
out_data = lax.mul(spenv.data(X), spenv.data(Y)) out_data = lax.mul(spenv.data(X), spenv.data(Y))