Migrate a subset of internal modules to use state objects

The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are friendlier to type checkers and IDEs.
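For illustration, here is a minimal sketch of the access-pattern change this migration makes (the flag names are taken from flags that appear in this diff; the sketch itself is not part of the commit):

    # Before: flag values were looked up dynamically on the global Config
    # object, so type checkers and IDEs could not see these attributes.
    from jax._src.config import config
    if config.jax_enable_checks:
        ...

    # After: each flag is a module-level, statically typed state object
    # whose current value is read via `.value`.
    from jax._src import config
    if config.enable_checks.value:
        ...

    # Context-manager usage stays the same, e.g.:
    with config.numpy_rank_promotion('allow'):
        ...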

PiperOrigin-RevId: 571932143
Sergei Lebedev 2023-10-09 07:28:18 -07:00 committed by jax authors
parent 14414363d0
commit 65d3058944
27 changed files with 232 additions and 212 deletions

View File

@ -23,6 +23,7 @@ import numpy as np
from jax._src import ad_util
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
@ -31,7 +32,6 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -529,7 +529,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
logger.log(logging.WARNING if config.jax_log_checkpoint_residuals
logger.log(logging.WARNING if config.log_checkpoint_residuals.value
else logging.DEBUG,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
@ -659,7 +659,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
assert not jaxpr.constvars
if differentiated and prevent_cse:
if config.jax_remat_opt_barrier:
if config.remat_opt_barrier.value:
translation_rule = _remat_translation_using_opt_barrier
elif is_gpu_platform:
translation_rule = _remat_translation_using_while

View File

@ -27,13 +27,13 @@ from jax._src import abstract_arrays
from jax._src import api
from jax._src import api_util
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.config import config
from jax._src.lib import xla_client as xc
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@ -172,7 +172,7 @@ class ArrayImpl(basearray.Array):
# input buffers are already arranged properly. This usually happens when
# Array's are created as output of a JAX transformation
# (like pjit, xmap, etc).
if not _skip_checks or config.jax_enable_checks:
if not _skip_checks or config.enable_checks.value:
self._check_and_rearrange()
def _check_and_rearrange(self):

View File

@ -20,7 +20,7 @@ import os
import struct
import sys
from jax._src.config import config
from jax._src import config
from jax._src.lib import version as jaxlib_version
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_client
@ -167,7 +167,7 @@ def _canonicalize_ir(m_original: ir.Module) -> bytes:
def _hash_computation(hash_obj, module):
if config.jax_compilation_cache_include_metadata_in_key:
if config.compilation_cache_include_metadata_in_key.value:
canonical_ir = _serialize_ir(module)
else:
canonical_ir = _canonicalize_ir(module)

View File

@ -27,6 +27,7 @@ from jax import lax
from jax._src import api
from jax._src import linear_util as lu
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
@ -37,7 +38,6 @@ from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -511,7 +511,7 @@ def check_lowering_rule(ctx, *args, err_tree, debug):
if debug:
# NOOP (check will only trigger when discharged)
return []
if not config.jax_experimental_unsafe_xla_runtime_errors:
if not config.xla_runtime_errors.value:
raise functionalization_error
out_op, _, _ = mlir.emit_python_callback(

View File

@ -30,32 +30,31 @@ import numpy as np
from jax._src import lib
from jax._src import compilation_cache
from jax._src import config as jax_config
from jax._src import config as config
from jax._src import monitoring
from jax._src import path
from jax._src import profiler
from jax._src import traceback_util
from jax._src.config import config
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool(
'jax_disable_most_optimizations',
jax_config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
_DUMP_IR_TO = jax_config.DEFINE_string(
_DUMP_IR_TO = config.DEFINE_string(
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
help="Path to which the IR that is emitted by JAX as input to the "
"compiler should be dumped as text files. Optional. If omitted, JAX "
"will not dump IR.")
_COMPILER_DETAILED_LOGGING_MIN_OPS = jax_config.DEFINE_integer(
_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer(
"jax_compiler_detailed_logging_min_ops",
jax_config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
help=(
'How big should a module be in MLIR operations before JAX enables '
'detailed compiler logging? The intent of this flag is to suppress '
@ -192,7 +191,7 @@ def get_compile_options(
# If the function returns 0, set -1; this is an error.
# -1 indicates that no attempt should be made to retrieve the latest profile
# later on.
jax_xla_profile_version = config.jax_xla_profile_version
jax_xla_profile_version = config.jax_xla_profile_version.value
if jax_xla_profile_version > 0:
compile_options.profile_version = jax_xla_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
@ -306,7 +305,7 @@ def compile_or_get_cached(
try:
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend,
jax_config.config.jax_use_original_compilation_cache_key_generation,
config.config.jax_use_original_compilation_cache_key_generation,
)
except xc._xla.XlaRuntimeError as ex:
logger.error("compile_or_get_cached: unable to generate cache key, "
@ -325,7 +324,7 @@ def compile_or_get_cached(
# TODO(b/293308239) Instrument metrics for new cache savings and cache hit
# rate after it is enabled.
if jax_config.config.jax_use_original_compilation_cache_key_generation:
if config.config.jax_use_original_compilation_cache_key_generation:
# TODO(b/293308239) Remove metrics for the original cache after the new
# compilation cache key implementation is fully rolled out.
monitoring.record_event('/jax/compilation_cache/cache_hits_original')
@ -358,7 +357,7 @@ def _cache_read(
return compilation_cache.get_executable_and_time(
cache_key, compile_options, backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
if config.raise_persistent_cache_errors.value:
raise
warnings.warn(
f"Error reading persistent compilation cache entry for "
@ -380,7 +379,7 @@ def _cache_write(cache_key: str,
"callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
return
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
min_compile_time = config.persistent_cache_min_compile_time_secs.value
if min_compile_time:
if compile_time_secs < min_compile_time:
logger.debug(
@ -399,7 +398,7 @@ def _cache_write(cache_key: str,
compilation_cache.put_executable_and_time(
cache_key, module_name, executable, backend, int(compile_time_secs))
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
if config.raise_persistent_cache_errors.value:
raise
warnings.warn(
f"Error writing persistent compilation cache entry for "

View File

@ -22,7 +22,7 @@ import logging
import os
import sys
import threading
from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar
from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar
from jax._src import lib
from jax._src.lib import jax_jit
@ -75,6 +75,12 @@ class FlagHolder(Generic[_T]):
self._flags = flags
self._name = name
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))
@property
def value(self) -> _T:
return getattr(self._flags, self._name)
@ -523,6 +529,12 @@ class _StateContextManager(Generic[_T]):
self._validate_new_val_hook = validate_new_val_hook
self._default_value = default_value
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))
@property
def value(self) -> _T:
val = _thread_local_state.__dict__.get(self._name, unset)
@ -935,7 +947,7 @@ use_original_compilation_cache_key_generation = config.define_bool_state(
"deployed, this flag and the original cache-key generation algorithm "
"will be removed.")
config.define_enum_state(
default_dtype_bits = config.define_enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
default='64',
@ -1090,7 +1102,7 @@ bcoo_cusparse_lowering = config.define_bool_state(
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
config.define_bool_state(
dynamic_shapes = config.define_bool_state(
name='jax_dynamic_shapes',
default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
help=('Enables experimental features for staging out computations with '
@ -1102,19 +1114,19 @@ config.define_bool_state(
# This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints.
config.define_bool_state(
remat_opt_barrier = config.define_bool_state(
name='jax_remat_opt_barrier',
default=(lib.version >= (0, 3, 6)),
help=('Enables using optimization-barrier op for lowering remat.'))
# TODO(sharadmv,mattjj): set default to True, then remove
config.define_bool_state(
eager_pmap = config.define_bool_state(
name='jax_eager_pmap',
default=True,
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
config.define_bool_state(
xla_runtime_errors = config.define_bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
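A side note on the `__bool__` overrides added to `FlagHolder` and `_StateContextManager` above: they make a forgotten `.value` fail loudly instead of silently treating the holder object as truthy. An illustrative sketch (not part of the diff):

    from jax._src import config

    # Correct: read the current value of the flag.
    checks_on = config.enable_checks.value

    # Incorrect: testing the holder itself now raises TypeError
    # (the message suggests using '.value' instead).
    try:
        if config.enable_checks:   # implicitly calls bool() on the state object
            pass
    except TypeError as e:
        print(e)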

View File

@ -36,9 +36,8 @@ from weakref import ref
import numpy as np
from jax._src import dtypes
from jax._src import config as jax_config
from jax._src import config
from jax._src import effects
from jax._src.config import config
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@ -60,9 +59,9 @@ zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer(
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.DEFINE_integer(
'jax_tracer_error_num_traceback_frames',
jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
help='Set the number of stack frames in JAX tracer error messages.'
)
@ -269,7 +268,7 @@ class JaxprEqn(NamedTuple):
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
source_info = source_info or source_info_util.new_source_info()
if config.jax_enable_checks:
if config.enable_checks.value:
assert all(isinstance(x, (Var, Literal)) for x in invars)
assert all(isinstance(v, Var) for v in outvars)
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
@ -381,7 +380,7 @@ class Primitive:
return f'{self.name}'
def bind(self, *args, **params):
assert (not config.jax_enable_checks or
assert (not config.enable_checks.value or
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
return self.bind_with_trace(find_top_trace(args), args, params)
@ -438,7 +437,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
return v.val if isinstance(v, Literal) else env[v]
def write(v: Var, val: Any) -> None:
if config.jax_enable_checks and not config.jax_dynamic_shapes:
if config.enable_checks.value and not config.dynamic_shapes.value:
assert typecheck(v.aval, val), (v.aval, val)
env[v] = val
@ -739,7 +738,7 @@ class Tracer(typing.Array):
def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert not config.jax_enable_checks or name != "aval"
assert not config.enable_checks.value or name != "aval"
try:
attr = getattr(self.aval, name)
@ -989,7 +988,7 @@ def _update_thread_local_jit_state(dynamic):
# TODO(mattjj): add a test that verifies that JIT-ted functions are not kept
# alive by the JIT cache, particularly for nested JIT-ted functions.
copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)
config.update_thread_local_jit_state(dynamic_trace_state=copy)
# The global state of the tracer is accessed by a thread-local object.
@ -1015,7 +1014,7 @@ def _initialize_jax_jit_thread_local_state():
if tls.extra_jit_context is None:
dynamic = thread_local_state.trace_state.trace_stack.dynamic
copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)
config.update_thread_local_jit_state(dynamic_trace_state=copy)
jax_jit.set_thread_local_state_initialization_callback(
@ -1151,7 +1150,7 @@ def new_main(trace_type: type[Trace], dynamic: bool = False,
stack.dynamic = prev_dynamic
_update_thread_local_jit_state(stack.dynamic)
if config.jax_check_tracer_leaks:
if config.check_tracer_leaks.value:
t = ref(main)
del main
if t() is not None:
@ -1188,7 +1187,7 @@ def new_base_main(trace_type: type[Trace],
stack.stack[0] = prev_base
_update_thread_local_jit_state(stack.dynamic)
if config.jax_check_tracer_leaks:
if config.check_tracer_leaks.value:
t = ref(main)
del main
if t() is not None:
@ -1268,7 +1267,7 @@ def new_sublevel() -> Generator[None, None, None]:
finally:
thread_local_state.trace_state.substack.pop()
if config.jax_check_tracer_leaks:
if config.check_tracer_leaks.value:
t = ref(sublevel)
del sublevel
if t() is not None:
@ -2026,9 +2025,9 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
return operator.index(dim)
except TypeError as e:
type_error = e
if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
if isinstance(dim, Tracer) and config.dynamic_shapes.value:
return dim
elif (config.jax_dynamic_shapes and isinstance(dim, DArray) and
elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
type(dim._aval.dtype) is bint and not dim._aval.shape):
return dim
elif is_dim(dim):
@ -2230,7 +2229,7 @@ class CallPrimitive(Primitive):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
@ -2458,14 +2457,14 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
frame = AxisEnvFrame(axis_name, size, tag)
ts = thread_local_state.trace_state
ts.axis_env.append(frame)
jax_config.update_thread_local_jit_state(
config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
try:
yield
finally:
ts.axis_env.pop()
jax_config.update_thread_local_jit_state(
config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
@ -2474,14 +2473,14 @@ def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None):
frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
ts = thread_local_state.trace_state
ts.axis_env.extend(frames)
jax_config.update_thread_local_jit_state(
config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
try:
yield
finally:
for _ in frames: ts.axis_env.pop()
jax_config.update_thread_local_jit_state(
config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
@ -2493,12 +2492,12 @@ def stash_axis_env():
# be raised.
ts = thread_local_state.trace_state
prev_axis_env, ts.axis_env = ts.axis_env, []
jax_config.update_thread_local_jit_state(axis_env_state=())
config.update_thread_local_jit_state(axis_env_state=())
try:
yield
finally:
ts.axis_env = prev_axis_env
jax_config.update_thread_local_jit_state(
config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))

View File

@ -18,6 +18,7 @@ from functools import update_wrapper, reduce, partial
import inspect
from typing import Any, Callable, Generic, Optional, TypeVar
from jax._src import config
from jax._src import core
from jax._src import custom_api_util
from jax._src.custom_transpose import custom_transpose
@ -28,7 +29,6 @@ from jax._src import traceback_util
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.config import config
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
@ -590,7 +590,7 @@ class custom_vjp(Generic[ReturnValue]):
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = _resolve_kwargs(self.fun, args, kwargs)
if config.jax_enable_custom_vjp_by_custom_transpose:
if config.enable_custom_vjp_by_custom_transpose.value:
if self.nondiff_argnums:
raise NotImplementedError(
'nondiff_argnums not implemented for new custom_vjp')
@ -1072,7 +1072,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
"""
flat_args, in_tree = tree_flatten(example_args)
in_avals = tuple(map(abstractify, flat_args))
if config.jax_check_tracer_leaks:
if config.check_tracer_leaks.value:
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)

View File

@ -30,6 +30,7 @@ import numpy as np
import jax
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
@ -39,7 +40,6 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -160,7 +160,7 @@ def xla_primitive_callable(
lowering_parameters=mlir.LoweringParameters())
compiled = computation.compile()
if xla_extension_version >= 192:
if config.jax_disable_jit:
if config.disable_jit.value:
call = compiled.unsafe_call
else:
call = compiled.create_cpp_call_for_apply_primitive(out_tree())
@ -262,7 +262,7 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
if _on_exit:
yield
else:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
start_time = time.time()
yield
elapsed_time = time.time() - start_time
@ -395,7 +395,7 @@ def _initial_style_primitive_replicas(params: dict[str, Any]) -> int:
default=1)
def needs_check_special() -> bool:
return config.jax_debug_infs or config.jax_debug_nans
return config.debug_infs.value or config.debug_nans.value
def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
if needs_check_special():
@ -404,9 +404,9 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
if dtypes.issubdtype(dtype, np.inexact):
if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")

View File

@ -19,7 +19,7 @@ import os
from typing import Any, Optional, Union
from jax._src import clusters
from jax._src.config import config
from jax._src import config
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
@ -63,8 +63,8 @@ class State:
if local_device_ids:
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
config.update("jax_cuda_visible_devices", visible_devices)
config.update("jax_rocm_visible_devices", visible_devices)
config.config.update("jax_cuda_visible_devices", visible_devices)
config.config.update("jax_rocm_visible_devices", visible_devices)
self.process_id = process_id

View File

@ -30,7 +30,7 @@ import warnings
import ml_dtypes
import numpy as np
from jax._src.config import config
from jax._src import config
from jax._src.typing import DType, DTypeLike
from jax._src import traceback_util
@ -59,7 +59,6 @@ class extended(np.generic):
>>> jnp.issubdtype(key.dtype, dtypes.extended)
True
"""
pass
class prng_key(extended):
@ -75,7 +74,6 @@ class prng_key(extended):
>>> jnp.issubdtype(key.dtype, dtypes.prng_key)
True
"""
pass
class ExtendedDType(metaclass=abc.ABCMeta):
@ -139,12 +137,28 @@ _int4_dtypes = [
]
# Default types.
bool_: type = np.bool_
int_: type = np.int32 if config.jax_default_dtype_bits == '32' else np.int64
uint: type = np.uint32 if config.jax_default_dtype_bits == '32' else np.uint64
float_: type = np.float32 if config.jax_default_dtype_bits == '32' else np.float64
complex_: type = np.complex64 if config.jax_default_dtype_bits == '32' else np.complex128
_default_types: dict[str, type] = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
bool_ = np.bool_
int_: type[Any]
uint: type[Any]
float_: type[Any]
complex_: type[Any]
if config.default_dtype_bits.value == '32':
int_ = np.int32
uint = np.uint32
float_ = np.float32
complex_ = np.complex64
else:
int_ = np.int64
uint = np.uint64
float_ = np.float64
complex_ = np.complex128
_default_types: dict[str, type[Any]] = {
'b': bool_,
'i': int_,
'u': uint,
'f': float_,
'c': complex_,
}
# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])
@ -219,7 +233,7 @@ def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False, allow_opa
"allow_opaque_dtype argument is deprecated; use allow_extended_dtype.",
DeprecationWarning)
allow_extended_dtype = allow_opaque_dtype
return _canonicalize_dtype(config.x64_enabled, allow_extended_dtype, dtype) # type: ignore[bad-return-type]
return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # type: ignore[bad-return-type]
# Default dtypes corresponding to Python scalars.
python_scalar_dtypes : dict[type, DType] = {
@ -507,7 +521,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
if len(LUB) == 1:
return LUB.pop()
elif len(LUB) == 0:
if config.jax_numpy_dtype_promotion == 'strict':
if config.numpy_dtype_promotion.value == 'strict':
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting "
@ -553,7 +567,7 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
# object identity, not object equality, due to the behavior of np.dtype.__eq__
a_tp = cast(JAXType, a if any(a is t for t in _weak_types) else np.dtype(a))
b_tp = cast(JAXType, b if any(b is t for t in _weak_types) else np.dtype(b))
return np.dtype(_least_upper_bound(config.jax_numpy_dtype_promotion, a_tp, b_tp))
return np.dtype(_least_upper_bound(config.numpy_dtype_promotion.value, a_tp, b_tp))
def is_weakly_typed(x: Any) -> bool:
try:
@ -604,17 +618,17 @@ def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
# Trivial promotion case. This allows extended dtypes through.
out_dtype = dtypes[0]
out_weak_type = False
elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
elif all(weak_types) and config.numpy_dtype_promotion.value != 'strict':
# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the
# incorrect result with non-canonical weak types (e.g. weak int16).
# TODO(jakevdp): explore removing this special case.
result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
*{_jax_type(dtype, False) for dtype in dtypes})
out_dtype = dtype(result_type)
out_weak_type = True
else:
result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
out_dtype = dtype(result_type)
out_weak_type = any(result_type is t for t in _weak_types)

View File

@ -29,6 +29,7 @@ from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
import warnings
from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects as effects_lib
@ -38,7 +39,6 @@ from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.config import config
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
@ -316,9 +316,8 @@ register_constant_handler(core.Token, _token_constant_handler)
def get_canonical_source_file(frame: source_info_util.Frame) -> str:
source_file = frame.file_name
if config.jax_hlo_source_file_canonicalization_regex:
source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
'', source_file)
if pattern := config.hlo_source_file_canonicalization_regex.value:
source_file = re.sub(pattern, '', source_file)
return source_file
def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
@ -349,7 +348,7 @@ def _source_info_to_location(
name_stack: source_info_util.NameStack) -> ir.Location:
eqn_str = (f'{str(source_info.name_stack)}/'
f'{core.str_eqn_compact(primitive.name, params)}')
if config.jax_include_full_tracebacks_in_locations:
if config.include_full_tracebacks_in_locations.value:
if source_info.traceback is None:
loc = ir.Location.unknown()
else:
@ -622,7 +621,7 @@ def sharded_aval(aval: core.AbstractValue,
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> tuple[int | Value, ...]:
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
ctx = ctx.replace(
@ -774,7 +773,7 @@ def lower_jaxpr_to_module(
host_callbacks: list[Any] = []
dim_vars: Sequence[str]
if not config.jax_dynamic_shapes:
if not config.dynamic_shapes.value:
# Find the dimension variables
all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
for d in aval.shape if not core.is_constant_dim(d)]
@ -1418,7 +1417,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, dim_var_values=dim_var_values)
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
axis_size_env = {d: read(d)[0]
for a in avals_in if type(a) is core.DShapedArray
for d in a.shape if type(d) is core.Var}
@ -1589,7 +1588,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# We might be applying this function to arguments with dynamic shapes,
# i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
# case, we need to form a jaxpr with leading binders for those axis size

View File

@ -26,13 +26,13 @@ from weakref import ref
import numpy as np
from jax._src import linear_util as lu
from jax._src.config import config
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import effects
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import profiler
from jax._src import source_info_util
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
@ -106,7 +106,7 @@ class PartialVal(tuple):
"""
def __new__(cls, xs: tuple[AbstractValue | None, core.Value]):
pv, const = xs
if config.jax_enable_checks:
if config.enable_checks.value:
# type checks
assert isinstance(pv, (AbstractValue, type(None))), xs
assert (const is None or isinstance(const, core.Tracer) or
@ -255,7 +255,7 @@ class JaxprTrace(Trace['JaxprTracer']):
# which were unknown to the first call (corresponding to in_avals).
# Wrap f to perform the partial evaluation and plumb out aux data.
if not config.jax_dynamic_shapes:
if not config.dynamic_shapes.value:
f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False)
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns),
tuple(in_avals))
@ -275,7 +275,7 @@ class JaxprTrace(Trace['JaxprTracer']):
out_consts, non_fwd_res_ = split_list(out, [sum(out_knowns)])
# Form the complete list of residuals by forwarding some inputs.
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# With dynamic shapes, we may need to forward implicit arguments.
in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns)
in_consts_full = [None] * len(f.in_type)
@ -303,7 +303,7 @@ class JaxprTrace(Trace['JaxprTracer']):
staged_params = update_params(staged_params, map(op.not_, in_knowns),
num_new_args)
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# With dynamic shapes, we may need to substitute Tracers into avals.
out_tracers = []
for aval, _ in out_type:
@ -958,21 +958,21 @@ def tracers_to_jaxpr(
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type]
outvars, eqns, jaxpr_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
config.jax_enable_checks and core.check_jaxpr(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects, debug_info=dbg)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@weakref_lru_cache
@ -980,24 +980,24 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
"""Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
raise NotImplementedError
config.jax_enable_checks and core.check_jaxpr(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
constvars, invars = split_list(jaxpr.invars, [n])
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names[n:])
lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
debug_info=dbg)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
raise NotImplementedError
config.jax_enable_checks and core.check_jaxpr(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr
@ -1090,7 +1090,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
# check jaxpr_known and jaxpr_unknown in isolation
# TODO(mattjj): enable weak type checking here
if config.jax_enable_checks:
if config.enable_checks.value:
core.check_jaxpr(jaxpr_known)
core.check_jaxpr(jaxpr_unknown)
# check jaxpr_known has input type corresponding to known inputs of jaxpr
@ -1256,7 +1256,7 @@ def _partial_eval_jaxpr_custom_cached(
known_outvars, known_eqns)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
known_eqns, known_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
config.enable_checks.value and core.check_jaxpr(jaxpr_known)
_, ins_staged = partition_list(in_inst, jaxpr.invars)
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
@ -1265,7 +1265,7 @@ def _partial_eval_jaxpr_custom_cached(
outs_staged, staged_eqns)
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns, staged_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
config.enable_checks.value and core.check_jaxpr(jaxpr_staged)
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
len(non_input_res_refs))
@ -1486,7 +1486,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr, used_inputs
@ -1702,7 +1702,7 @@ class JaxprStackFrame:
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
jaxpr, out_type = _add_implicit_outputs(jaxpr)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, out_type, constvals
def newvar(self, aval):
@ -2226,7 +2226,7 @@ def trace_to_subjaxpr_dynamic(
out_tracers = map(trace.full_raise, ans)
jaxpr, consts = frame.to_jaxpr(out_tracers)
del fun, main, trace, frame, in_tracers, out_tracers, ans
config.jax_enable_checks and core.check_jaxpr(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts
@ -2432,7 +2432,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
jaxpr.effects, jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return new_jaxpr, out_type

View File

@ -34,8 +34,9 @@ import jax
from jax.errors import JAXTypeError
from jax._src import api_util
from jax._src import core
from jax._src import compiler
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
@ -51,7 +52,6 @@ from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import DShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
@ -279,7 +279,7 @@ def xla_pmap_impl_lazy(
donated_invars: Sequence[bool],
is_explicit_global_axis_size: bool,
) -> Callable:
if (config.jax_disable_jit and config.jax_eager_pmap and
if (config.disable_jit.value and config.eager_pmap.value and
not is_explicit_global_axis_size and not any(d for d in donated_invars)):
def _emap_apply_fn(*args):
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
@ -296,7 +296,7 @@ def xla_pmap_impl_lazy(
is_explicit_global_axis_size, *abstract_args)
# Don't re-abstractify args unless logging is enabled for performance.
if config.jax_distributed_debug:
if config.distributed_debug.value:
distributed_debug_log(("Running pmapped function", name),
("python function", fun.f),
("devices", devices),
@ -433,7 +433,7 @@ class MapTrace(core.Trace):
def process_map(self, map_primitive, fun, tracers, params):
if params['devices'] is not None:
raise ValueError("Nested pmap with explicit devices argument.")
if not config.jax_disable_jit:
if not config.disable_jit.value:
bind = HashableFunction(
lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
(map_primitive, fun))
@ -728,7 +728,7 @@ def lower_parallel_callable(
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
f"{replicas.jaxpr_replicas}")
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
@ -1630,7 +1630,7 @@ ShardingInfo = tuple[
def _get_default_device() -> xc.Device:
return config.jax_default_device or xb.local_devices()[0]
return config.default_device.value or xb.local_devices()[0]
class _thread_local_decorator(threading.local):
@ -1798,7 +1798,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
# code without tuple conversion.
device_assignment = tuple(da_object)
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s for with global shapes and types %s. "
@ -2029,7 +2029,7 @@ def lower_sharding_computation(
transfer_mem_kind_in_jaxpr))
if not da_object.is_fully_addressable: # type: ignore
if inline and config.jax_spmd_mode != 'allow_all':
if inline and config.spmd_mode.value != 'allow_all':
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
@ -2116,7 +2116,7 @@ def lower_mesh_computation(
global_axis_sizes = mesh.shape
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s for %s mesh with global shapes and types %s. "

View File

@ -36,6 +36,7 @@ from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@ -45,7 +46,6 @@ from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.interpreters import ad
@ -93,7 +93,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
raise TypeError(msg)
assert shapes
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# pass dynamic shapes through unchecked
return
else:
@ -182,7 +182,7 @@ def _extract_tracers_dyn_shape(
shape: Sequence[Union[int, core.Tracer]]
) -> tuple[list[core.Tracer], list[Optional[int]]]:
# Given a sequence representing a shape, pull out Tracers, replacing with None
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information).
dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
@ -794,7 +794,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
"""
if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array):
return type_cast(Array, operand)
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information).
dyn_shape, static_shape = _extract_tracers_dyn_shape(shape)
@ -1649,7 +1649,7 @@ def _unbroadcast(aval, x):
return _reduce_sum(x, list(range(len(x_shape))))
else:
dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)]
if config.jax_enable_checks: assert all(aval.shape[i] == 1 for i in dims)
if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims)
return reshape(_reduce_sum(x, dims), aval.shape)
def _maybe_broadcast(target_shape, x):
@ -3358,7 +3358,7 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions):
msg = 'reshape new_sizes must all be positive, got {}.'
raise TypeError(msg.format(new_sizes))
# TODO(necula): re-enable this check
if (not config.jax_dynamic_shapes and
if (not config.dynamic_shapes.value and
not math.prod(np.shape(operand)) == math.prod(new_sizes)):
msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
raise TypeError(msg.format(new_sizes, np.shape(operand)))
@ -4850,7 +4850,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
# bool(obj) for an ndarray raises an error, so we check len
if not len(obj): # pylint: disable=g-explicit-length-test
return
if (config.jax_dynamic_shapes and isinstance(obj, (tuple, list)) and
if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and
any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)):
return # TODO(mattjj): handle more checks in the dynamic shape case
obj_arr = np.array(obj)
@ -4913,17 +4913,17 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[Precision
value to apply to both operands, or as a sequence of two values.
"""
if precision is None:
if config.jax_default_matmul_precision is None:
if config.default_matmul_precision.value is None:
return None
try:
return type_cast(
tuple[PrecisionType, PrecisionType],
(Precision(config.jax_default_matmul_precision),
Precision(config.jax_default_matmul_precision)))
(Precision(config.default_matmul_precision.value),
Precision(config.default_matmul_precision.value)))
except TypeError:
raise ValueError(
"jax_default_matmul_precision flag must be set to None or a value in "
f"{list(Precision._strings)}, but got {config.jax_default_matmul_precision}"
f"{list(Precision._strings)}, but got {config.default_matmul_precision.value}"
) from None
elif isinstance(precision, str) and precision in Precision._strings:
return type_cast(tuple[PrecisionType, PrecisionType],

View File

@ -68,12 +68,13 @@ import operator
from typing import Any, Callable, NamedTuple
import weakref
from jax._src.tree_util import tree_map
from jax._src.config import config
from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import tree_map
from jax._src.util import curry
traceback_util.register_exclusion(__file__)
@ -333,13 +334,13 @@ def cache(call: Callable):
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
if config.check_tracer_leaks.value:
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
config.x64_enabled, config.jax_default_device,
config._trace_context())
config.enable_x64.value, config.default_device.value,
config.config._trace_context())
else:
key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
config.jax_default_device, config._trace_context())
key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
config.default_device.value, config.config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result

View File

@ -20,10 +20,10 @@ from typing import Any, Callable, NamedTuple, Optional, TypeVar
import warnings
from jax._src import dtypes
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src.config import config
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
@ -161,7 +161,7 @@ def _wraps(
try:
mod = module or fun.__module__
except AttributeError:
if config.jax_enable_checks:
if config.enable_checks.value:
raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.")
else:
name = f"{mod}.{name}"
@ -206,7 +206,7 @@ def _wraps(
if kept_sections:
docstr += "\n" + "\n\n".join(kept_sections) + "\n"
except:
if config.jax_enable_checks:
if config.enable_checks.value:
raise
docstr = fun.__doc__
@ -229,7 +229,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
return [lax.asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# With dynamic shapes we don't support singleton-dimension broadcasting;
# we instead broadcast out to the full shape as a temporary workaround.
# TODO(mattjj): revise this workaround
@ -242,7 +242,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
if len(nonscalar_ranks) < 2:
return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
if config.numpy_rank_promotion.value != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
result_rank = len(lax.broadcast_shapes(*shapes))
return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
@ -250,13 +250,13 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
if config.jax_numpy_rank_promotion == "warn":
if config.numpy_rank_promotion.value == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
"disable this warning; for more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
elif config.jax_numpy_rank_promotion == "raise":
elif config.numpy_rank_promotion.value == "raise":
msg = ("Operands could not be broadcast together for {} on shapes {} "
"and with the config option jax_numpy_rank_promotion='raise'. "
"For more information, see "

View File

@ -25,6 +25,7 @@ import warnings
import numpy as np
from jax._src import config
from jax._src import core
from jax._src import stages
from jax._src import dispatch
@ -46,7 +47,6 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import xla
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -182,7 +182,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
@wraps(fun)
@api_boundary
def wrapped(*args, **kwargs):
if config.jax_disable_jit:
if config.disable_jit.value:
return fun(*args, **kwargs)
return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
@ -275,7 +275,7 @@ def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, donate_argnames,
static_argnums, static_argnames, device,
backend, abstracted_axes):
if abstracted_axes and not config.jax_dynamic_shapes:
if abstracted_axes and not config.dynamic_shapes.value:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
check_callable(fun)
@ -432,7 +432,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
dyn_kwargs = {}
del kwargs
if (donate_argnums or donate_argnames) and not config.jax_debug_nans:
if (donate_argnums or donate_argnames) and not config.debug_nans.value:
donated_invars = donation_vector(
donate_argnums, donate_argnames, dyn_args, dyn_kwargs)
else:
@ -462,7 +462,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
assert in_shardings is not None or all(i is not None for i in in_shardings)
assert out_shardings is not None or all(o is not None for o in out_shardings)
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e)
else:
@ -490,7 +490,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
assert len(explicit_args) == len(canonicalized_in_shardings_flat)
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
implicit_args = _extract_implicit_args(in_type, explicit_args)
else:
implicit_args = []
@ -892,7 +892,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
"pjit in_shardings", in_tree, orig_in_shardings,
tupled_args=True)
if not config.jax_dynamic_shapes:
if not config.dynamic_shapes.value:
pjit_check_aval_sharding(in_shardings_flat, in_avals,
None if debug_info is None else debug_info.arg_names,
"pjit arguments", allow_uneven_sharding=False)
@ -909,14 +909,14 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, in_type), debug_info=pe_debug)
else:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
fun, in_type, debug_info=pe_debug)
if not config.jax_dynamic_shapes:
if not config.dynamic_shapes.value:
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
if any(isinstance(c, core.Tracer) for c in consts):
@ -944,7 +944,7 @@ def _check_and_canonicalize_out_shardings(
"pjit out_shardings", out_tree(), orig_out_shardings,
tupled_args=False)
if not config.jax_dynamic_shapes:
if not config.dynamic_shapes.value:
pjit_check_aval_sharding(
out_shardings_flat, out_type,
None if debug_info is None else debug_info.result_paths,
@ -1132,10 +1132,10 @@ def _pjit_call_impl_python(
lowering_parameters=mlir.LoweringParameters()).compile()
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.jax_enable_checks:
if compiled._auto_spmd_lowering and config.enable_checks.value:
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
jaxpr.jaxpr.debug_info)
if config.jax_distributed_debug:
if config.distributed_debug.value:
# Defensively only perform fingerprint logic if debug logging is enabled
# NOTE(skyewm): I didn't benchmark this
fingerprint = None
@ -1151,7 +1151,7 @@ def _pjit_call_impl_python(
try:
return compiled.unsafe_call(*args), compiled
except FloatingPointError:
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
@ -1159,7 +1159,7 @@ def _pjit_call_impl_python(
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
msg = ("An invalid value was encountered in the output of the "
f"`jit`-decorated function {name}. Because "
"config.jax_debug_nans and/or config.jax_debug_infs is set, the "
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
@ -1313,7 +1313,7 @@ def pjit_staging_rule(trace, *args, **params):
jaxpr = params['jaxpr']
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
propagate_source_info=False)
elif config.jax_dynamic_shapes:
elif config.dynamic_shapes.value:
source_info = source_info_util.current()
out_tracers = []
for aval in _out_type(params['jaxpr']):

View File

@ -31,7 +31,7 @@ from jax import tree_util
from jax._src import ad_util
from jax._src import api
from jax._src import basearray
from jax._src import config as config_lib
from jax._src import config as config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@ -40,7 +40,6 @@ from jax._src import sharding_specs
from jax._src import tree_util as tree_util_internal
from jax._src import typing
from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.dtypes import float0
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -969,7 +968,7 @@ def _threefry_seed(seed: typing.Array) -> typing.Array:
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
k1 = convert(
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
with config_lib.numpy_dtype_promotion('standard'):
with config.numpy_dtype_promotion('standard'):
# TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
# inputs. We should avoid this.
k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
@ -1270,7 +1269,7 @@ def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:
@partial(jit, static_argnums=(1,))
def _threefry_split(key, shape) -> typing.Array:
if config.jax_threefry_partitionable:
if config.threefry_partitionable.value:
return _threefry_split_foldlike(key, shape) # type: ignore
else:
return _threefry_split_original(key, shape) # type: ignore
@ -1305,7 +1304,7 @@ def threefry_random_bits(key: typing.Array, bit_width, shape):
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
if config.jax_threefry_partitionable:
if config.threefry_partitionable.value:
return _threefry_random_bits_partitionable(key, bit_width, shape)
else:
return _threefry_random_bits_original(key, bit_width, shape)
@ -1398,7 +1397,7 @@ def _rbg_seed(seed: typing.Array) -> typing.Array:
return jnp.concatenate([halfkey, halfkey])
def _rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
if config.jax_threefry_partitionable:
if config.threefry_partitionable.value:
_threefry_split = _threefry_split_foldlike
else:
_threefry_split = _threefry_split_original

View File

@ -16,8 +16,8 @@ from functools import partial
import operator
from jax._src import api
from jax._src import config
from jax._src import dtypes as _dtypes
from jax._src.config import config
from jax._src.tree_util import tree_map, tree_reduce
import numpy as np
@ -130,7 +130,7 @@ def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
def _check_dtypes_match(xs, ys):
def _assert_dtypes_match(x, y):
if config.x64_enabled:
if config.enable_x64.value:
assert _dtype(x) == _dtype(y)
else:
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==

View File

@ -26,13 +26,12 @@ import jax.numpy as jnp
from jax import lax
from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import config as config_lib
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -77,17 +76,17 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
elif _arraylike(key):
# Call random_wrap here to surface errors for invalid keys.
wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
if config.jax_legacy_prng_key == 'error':
if config.legacy_prng_key.value == 'error':
raise ValueError(
'Legacy uint32 key array passed as key to jax.random function. '
'Please create keys using jax.random.key(). If use of a raw key array '
'was intended, set jax_legacy_prng_key="allow".')
elif config.jax_legacy_prng_key == 'warn':
elif config.legacy_prng_key.value == 'warn':
warnings.warn(
'Legacy uint32 key array passed as key to jax.random function. '
'Please create keys using jax.random.key(). If use of a raw key array '
'was intended, set jax_legacy_prng_key="allow".', stacklevel=2)
elif config.jax_enable_custom_prng:
elif config.enable_custom_prng.value:
# TODO(jakevdp): possibly remove this warning condition.
warnings.warn(
'Raw arrays as random keys to jax.random functions are deprecated. '
@ -101,7 +100,7 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
def _return_prng_keys(was_wrapped, key):
# TODO(frostig): remove once we always enable_custom_prng
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
if config.jax_enable_custom_prng:
if config.enable_custom_prng.value:
return key
else:
return prng.random_unwrap(key) if was_wrapped else key
@ -120,7 +119,7 @@ def default_prng_impl():
The default implementation is determined by ``config.jax_default_prng_impl``,
which specifies it by name.
"""
impl_name = config.jax_default_prng_impl
impl_name = config.default_prng_impl.value
assert impl_name in prng.prngs, impl_name
return prng.prngs[impl_name]
@ -225,8 +224,8 @@ def PRNGKey(seed: Union[int, Array], *,
# TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name):
default_impl = default_prng_impl()
default_name = config.jax_default_prng_impl
if not config.jax_enable_custom_prng and default_impl is not impl:
default_name = config.default_prng_impl.value
if not config.enable_custom_prng.value and default_impl is not impl:
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
f'to seed an RNG with an implementation "{name}" '
f'differing from the default "{default_name}".')
@ -829,7 +828,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
else: # 'cholesky'
factor = cholesky(cov)
normal_samples = normal(key, shape + mean.shape[-1:], dtype)
with config_lib.numpy_rank_promotion('allow'):
with config.numpy_rank_promotion('allow'):
result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
return result

View File

@ -24,11 +24,11 @@ import numpy as np
from jax._src import api_util
from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import tree_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
@ -262,7 +262,7 @@ run_state_p.multiple_results = True
def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...]):
if config.jax_enable_checks:
if config.enable_checks.value:
core.check_jaxpr(jaxpr)
assert len(jaxpr.invars) == len(args)
assert len(which_linear) == len(args)

View File

@ -27,7 +27,7 @@ from typing import Any, Callable
import jax
from jax import core
from jax._src.config import config
from jax._src import config
from jax._src.lib import tpu_mosaic
from jax._src.lib import xla_client
from jax.interpreters import mlir
@ -38,7 +38,7 @@ from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
import numpy as np
config.define_bool_state(
mosaic_use_cpp_passes = config.config.define_bool_state(
name="mosaic_use_cpp_passes",
default=False,
help=(
@ -54,13 +54,13 @@ tpu = tpu_mosaic.tpu
apply_vector_layout = tpu_mosaic.apply_vector_layout
infer_memref_layout = tpu_mosaic.infer_memref_layout
config.define_bool_state(
mosaic_allow_hlo = config.config.define_bool_state(
name="jax_mosaic_allow_hlo",
default=False,
help="Allow hlo dialects in Mosaic",
)
config.define_bool_state(
mosaic_dump_mlir = config.config.define_bool_state(
name="jax_mosaic_dump_mlir",
default=False,
help="Print mlir module after each pass",
@ -243,7 +243,7 @@ def _lower_tpu_kernel(
)
dump_mlir(module, "initial module")
if config.jax_mosaic_allow_hlo:
if mosaic_allow_hlo.value:
# Run hlo dialect conversion: hlo -> linalg -> vector.
pipeline = [
"hlo-legalize-to-arithmetic",
@ -255,7 +255,7 @@ def _lower_tpu_kernel(
)
dump_mlir(module, "after hlo conversion module")
if config.mosaic_use_cpp_passes:
if mosaic_use_cpp_passes.value:
pipeline = [
(
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
@ -278,7 +278,7 @@ def _lower_tpu_kernel(
module.operation.verify()
dump_mlir(module, "after infer vector layout pass")
if config.mosaic_use_cpp_passes:
if mosaic_use_cpp_passes.value:
pipeline = [
(
"func.func(tpu-apply-vector-layout{sublane-count=8"
@ -418,6 +418,6 @@ def _lowered_as_tpu_kernel(
def dump_mlir(module: ir.Module, msg: str):
"""A helper function to print mlir module with a message."""
if config.jax_mosaic_dump_mlir:
if mosaic_dump_mlir.value:
print(f"[jax_mosaic_dump_mlir] {msg}")
print(module)
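
Beyond call-site reads, this file also changes the flag definitions themselves: `define_bool_state` is now invoked on the config object of the imported module (`config.config.define_bool_state(...)`), its result is bound to a module-level holder such as `mosaic_use_cpp_passes`, and later checks read `holder.value` rather than looking the flag up by name. A small sketch following that pattern; the flag name and helper are invented for illustration:

    from jax._src import config

    # Keep the returned holder so call sites can read it directly.
    _example_dump_ir = config.config.define_bool_state(
        name="example_dump_ir",  # hypothetical flag, mirrors jax_mosaic_dump_mlir
        default=False,
        help="Print the module after each pass (illustration only).",
    )

    def maybe_dump(module) -> None:
        # Typed read replaces a by-name lookup on the config object.
        if _example_dump_ir.value:
            print(module)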

View File

@ -19,9 +19,9 @@ import traceback
import types
from typing import Any, Callable, Optional, TypeVar, cast
from jax._src.config import config
from jax._src.lib import xla_extension
from jax._src import config
from jax._src import util
from jax._src.lib import xla_extension
C = TypeVar("C", bound=Callable[..., Any])
@ -139,7 +139,7 @@ def _ipython_supports_tracebackhide() -> bool:
return IPython.version_info[:2] >= (7, 17)
def _filtering_mode() -> str:
mode = config.jax_traceback_filtering
mode = config.traceback_filtering.value
if mode is None or mode == "auto":
if (_running_under_ipython() and _ipython_supports_tracebackhide()):
mode = "tracebackhide"

View File

@ -23,9 +23,9 @@ import weakref
import numpy as np
from jax._src import config
from jax._src.lib import xla_client as xc
from jax._src.lib import utils as jaxlib_utils
from jax._src.config import config
logger = logging.getLogger(__name__)
@ -257,10 +257,10 @@ def cache(max_size=4096):
@functools.wraps(f)
def wrapper(*args, **kwargs):
if config.jax_check_tracer_leaks:
if config.check_tracer_leaks.value:
return f(*args, **kwargs)
else:
return cached(config._trace_context(), *args, **kwargs)
return cached(config.config._trace_context(), *args, **kwargs)
wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
@ -278,7 +278,7 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
behave similar to `functools.lru_cache`.
"""
global _weakref_lru_caches
cached_call = xc.weakref_lru_cache(config._trace_context, call, maxsize)
cached_call = xc.weakref_lru_cache(config.config._trace_context, call, maxsize)
_weakref_lru_caches.add(cached_call)
return cached_call
@ -464,7 +464,7 @@ def distributed_debug_log(*pairs):
pairs: A sequence of label/value pairs to log. The first pair is treated as
a heading for subsequent pairs.
"""
if config.jax_distributed_debug:
if config.distributed_debug.value:
lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
try:
lines.append(f"{pairs[0][0]}: {pairs[0][1]}")

View File

@ -35,9 +35,8 @@ import threading
from typing import Any, Callable, Optional, Union
import warnings
from jax._src import config
from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import config
from jax._src.lib import cuda_versions
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
@ -63,34 +62,34 @@ XlaBackend = xla_client.Client
# TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = jax_config.DEFINE_string(
_XLA_BACKEND = config.DEFINE_string(
'jax_xla_backend', '',
'Deprecated, please use --jax_platforms instead.')
BACKEND_TARGET = jax_config.DEFINE_string(
BACKEND_TARGET = config.DEFINE_string(
'jax_backend_target',
os.getenv('JAX_BACKEND_TARGET', '').lower(),
'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
_PLATFORM_NAME = jax_config.DEFINE_string(
_PLATFORM_NAME = config.DEFINE_string(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Deprecated, please use --jax_platforms instead.')
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
CUDA_VISIBLE_DEVICES = config.DEFINE_string(
'jax_cuda_visible_devices', 'all',
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
_ROCM_VISIBLE_DEVICES = config.DEFINE_string(
'jax_rocm_visible_devices', 'all',
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
_USE_MOCK_GPU_CLIENT = jax_config.DEFINE_bool(
_USE_MOCK_GPU_CLIENT = config.DEFINE_bool(
name="use_mock_gpu_client",
default=False,
help="If True, use a mock GPU client instead of a real one.",
)
_MOCK_NUM_GPUS = jax_config.DEFINE_integer(
_MOCK_NUM_GPUS = config.DEFINE_integer(
name="mock_num_gpus",
default=1,
help="Mock GPU client number of gpus.",
@ -223,7 +222,7 @@ def _check_cuda_versions():
def make_gpu_client(
*, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str]
*, platform_name: str, visible_devices_flag: config.FlagHolder[str]
) -> xla_client.Client:
visible_devices = visible_devices_flag.value
allowed_devices = None
@ -564,11 +563,10 @@ def backends() -> dict[str, xla_client.Client]:
with _backend_lock:
if _backends:
return _backends
if config.jax_platforms:
jax_platforms = config.jax_platforms.split(",")
if jax_platforms := config.jax_platforms.value:
platforms = []
# Allow platform aliases in the list of platforms.
for platform in jax_platforms:
for platform in jax_platforms.split(","):
platforms.extend(expand_platform_alias(platform))
priorities = range(len(platforms), 0, -1)
# If the user specified a list of platforms explicitly, always fail
@ -597,14 +595,14 @@ def backends() -> dict[str, xla_client.Client]:
_backend_errors[platform] = str(err)
logger.info(err_msg)
else:
if config.jax_platforms:
if config.jax_platforms.value:
err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
else:
err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
raise RuntimeError(err_msg)
assert _default_backend is not None
if not config.jax_platforms:
if not config.jax_platforms.value:
_suggest_missing_backends()
return _backends
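
The xla_bridge hunks show the other half of the pattern: command-line flags are declared with `config.DEFINE_string(...)` / `config.DEFINE_bool(...)` and passed around as typed `config.FlagHolder` objects whose current value is read via `.value`, and runtime options such as `jax_platforms` are likewise read through `.value`. A short sketch of the flag-holder usage; the flag name and function are invented and merely mirror `jax_backend_target` above:

    from jax._src import config

    _EXAMPLE_TARGET = config.DEFINE_string(
        'example_backend_target', '',
        'Illustrative string flag read through a FlagHolder.')

    def resolve_target(target_flag: config.FlagHolder[str]) -> str:
        # The holder exposes the flag's current value via .value.
        return target_flag.value or 'local'

    print(resolve_target(_EXAMPLE_TARGET))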

View File

@ -55,6 +55,7 @@ import numpy as np
import jax
from jax import lax
from jax._src import config
from jax._src import core
from jax._src.custom_derivatives import lift_jvp
from jax._src import linear_util as lu
@ -67,7 +68,6 @@ from jax._src.lib import pytree
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import safe_map, safe_zip, split_list
from jax._src.config import config
from jax._src.lax.control_flow import _check_tree_and_avals
from jax._src.numpy import lax_numpy
from jax.experimental import sparse
@ -614,7 +614,7 @@ def _add_sparse(spenv, *spvalues):
raise NotImplementedError("Addition between sparse matrices of different shapes.")
if X.indices_ref == Y.indices_ref:
out_data = lax.add(spenv.data(X), spenv.data(Y))
if config.jax_enable_checks:
if config.enable_checks.value:
assert X.indices_sorted == Y.indices_sorted
assert X.unique_indices == Y.unique_indices
out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref,
@ -657,7 +657,7 @@ def _mul_sparse(spenv, *spvalues):
X, Y = spvalues
if X.is_sparse() and Y.is_sparse():
if X.indices_ref == Y.indices_ref and X.unique_indices:
if config.jax_enable_checks:
if config.enable_checks.value:
assert X.indices_sorted == Y.indices_sorted
assert X.unique_indices == Y.unique_indices
out_data = lax.mul(spenv.data(X), spenv.data(Y))