Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16)
Migrate a subset of internal modules to use state objects

The motivation here is to gradually replace all dynamic lookups on `jax.config` with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
parent 14414363d0
commit 65d3058944
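To make the migration concrete before the diff: below is a minimal sketch of the pattern this commit moves toward, simplified from what jax._src.config actually does. The flag name jax_example_flag is hypothetical, used only for illustration; define_bool_state, the .value property, and the __bool__ guard all appear in the diff that follows.

    # Sketch only (assumes the simplified API shown in the diff below).
    example_flag = config.define_bool_state(
        name='jax_example_flag',   # hypothetical flag name, for illustration
        default=False,
        help='Demonstrates the statically-typed state-object pattern.')

    # Old style, removed by this commit: a dynamic, untyped attribute lookup.
    #   if config.jax_example_flag: ...

    # New style: an explicit read through the holder's typed .value property.
    if example_flag.value:
      ...

    # Misuse guard: the holder's __bool__ raises, so forgetting .value fails
    # loudly instead of evaluating the holder object itself as truthy.
    try:
      bool(example_flag)
    except TypeError as err:
      print(err)  # "bool() not supported ... did you mean '.value' instead?"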
@@ -23,6 +23,7 @@ import numpy as np

 from jax._src import ad_util
 from jax._src import api
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import linear_util as lu
@@ -31,7 +32,6 @@ from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
 from jax._src.api_util import flatten_fun, shaped_abstractify
-from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -529,7 +529,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
   res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
   res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
   body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
-  logger.log(logging.WARNING if config.jax_log_checkpoint_residuals
+  logger.log(logging.WARNING if config.log_checkpoint_residuals.value
              else logging.DEBUG,
              'remat-decorated function ' +
              'saving inputs with shapes:\n' * bool(res_invars) +
@@ -659,7 +659,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
   assert not jaxpr.constvars

   if differentiated and prevent_cse:
-    if config.jax_remat_opt_barrier:
+    if config.remat_opt_barrier.value:
       translation_rule = _remat_translation_using_opt_barrier
     elif is_gpu_platform:
       translation_rule = _remat_translation_using_while
@@ -27,13 +27,13 @@ from jax._src import abstract_arrays
 from jax._src import api
 from jax._src import api_util
 from jax._src import basearray
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import profiler
 from jax._src import tree_util
 from jax._src import xla_bridge
-from jax._src.config import config
 from jax._src.lib import xla_client as xc
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
@@ -172,7 +172,7 @@ class ArrayImpl(basearray.Array):
     # input buffers are already arranged properly. This usually happens when
     # Array's are created as output of a JAX transformation
     # (like pjit, xmap, etc).
-    if not _skip_checks or config.jax_enable_checks:
+    if not _skip_checks or config.enable_checks.value:
       self._check_and_rearrange()

   def _check_and_rearrange(self):
@@ -20,7 +20,7 @@ import os
 import struct
 import sys

-from jax._src.config import config
+from jax._src import config
 from jax._src.lib import version as jaxlib_version
 from jax._src.lib import version_str as jaxlib_version_str
 from jax._src.lib import xla_client
@@ -167,7 +167,7 @@ def _canonicalize_ir(m_original: ir.Module) -> bytes:


 def _hash_computation(hash_obj, module):
-  if config.jax_compilation_cache_include_metadata_in_key:
+  if config.compilation_cache_include_metadata_in_key.value:
     canonical_ir = _serialize_ir(module)
   else:
     canonical_ir = _canonicalize_ir(module)
@@ -27,6 +27,7 @@ from jax import lax

 from jax._src import api
 from jax._src import linear_util as lu
+from jax._src import config
 from jax._src import core
 from jax._src import custom_derivatives
 from jax._src import effects
@@ -37,7 +38,6 @@ from jax._src import traceback_util
 from jax._src import tree_util as jtu
 from jax._src.ad_util import SymbolicZero
 from jax._src.api_util import flatten_fun
-from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -511,7 +511,7 @@ def check_lowering_rule(ctx, *args, err_tree, debug):
   if debug:
     # NOOP (check will only trigger when discharged)
     return []
-  if not config.jax_experimental_unsafe_xla_runtime_errors:
+  if not config.xla_runtime_errors.value:
     raise functionalization_error

   out_op, _, _ = mlir.emit_python_callback(
@@ -30,32 +30,31 @@ import numpy as np

 from jax._src import lib
 from jax._src import compilation_cache
-from jax._src import config as jax_config
+from jax._src import config as config
 from jax._src import monitoring
 from jax._src import path
 from jax._src import profiler
 from jax._src import traceback_util
-from jax._src.config import config
 from jax._src.lib.mlir import ir
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version


-_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
+_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool(
     'jax_disable_most_optimizations',
-    jax_config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
+    config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
     'Try not to do much optimization work. This can be useful if the cost of '
     'optimization is greater than that of running a less-optimized program.')

-_DUMP_IR_TO = jax_config.DEFINE_string(
+_DUMP_IR_TO = config.DEFINE_string(
     'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
     help="Path to which the IR that is emitted by JAX as input to the "
          "compiler should be dumped as text files. Optional. If omitted, JAX "
          "will not dump IR.")

-_COMPILER_DETAILED_LOGGING_MIN_OPS = jax_config.DEFINE_integer(
+_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer(
     "jax_compiler_detailed_logging_min_ops",
-    jax_config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
+    config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
     help=(
         'How big should a module be in MLIR operations before JAX enables '
         'detailed compiler logging? The intent of this flag is to suppress '
@@ -192,7 +191,7 @@ def get_compile_options(
   # If the function returns 0, set -1; this is an error.
   # -1 indicates that no attempt should be made to retrieve the latest profile
   # later on.
-  jax_xla_profile_version = config.jax_xla_profile_version
+  jax_xla_profile_version = config.jax_xla_profile_version.value
   if jax_xla_profile_version > 0:
     compile_options.profile_version = jax_xla_profile_version
     logger.debug("get_compile_options XLA-AutoFDO profile: " +
@@ -306,7 +305,7 @@ def compile_or_get_cached(
   try:
     cache_key = compilation_cache.get_cache_key(
         computation, devices, compile_options, backend,
-        jax_config.config.jax_use_original_compilation_cache_key_generation,
+        config.config.jax_use_original_compilation_cache_key_generation,
     )
   except xc._xla.XlaRuntimeError as ex:
     logger.error("compile_or_get_cached: unable to generate cache key, "
@@ -325,7 +324,7 @@ def compile_or_get_cached(

   # TODO(b/293308239) Instrument metrics for new cache savings and cache hit
   # rate after it is enabled.
-  if jax_config.config.jax_use_original_compilation_cache_key_generation:
+  if config.config.jax_use_original_compilation_cache_key_generation:
     # TODO(b/293308239) Remove metrics for the original cache after the new
     # compilation cache key implementation is fully rolled out.
     monitoring.record_event('/jax/compilation_cache/cache_hits_original')
@@ -358,7 +357,7 @@ def _cache_read(
     return compilation_cache.get_executable_and_time(
         cache_key, compile_options, backend)
   except Exception as ex:
-    if config.jax_raise_persistent_cache_errors:
+    if config.raise_persistent_cache_errors.value:
       raise
     warnings.warn(
         f"Error reading persistent compilation cache entry for "
@@ -380,7 +379,7 @@ def _cache_write(cache_key: str,
         "callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
     return

-  min_compile_time = config.jax_persistent_cache_min_compile_time_secs
+  min_compile_time = config.persistent_cache_min_compile_time_secs.value
   if min_compile_time:
     if compile_time_secs < min_compile_time:
       logger.debug(
@@ -399,7 +398,7 @@ def _cache_write(cache_key: str,
     compilation_cache.put_executable_and_time(
         cache_key, module_name, executable, backend, int(compile_time_secs))
   except Exception as ex:
-    if config.jax_raise_persistent_cache_errors:
+    if config.raise_persistent_cache_errors.value:
       raise
     warnings.warn(
         f"Error writing persistent compilation cache entry for "
@@ -22,7 +22,7 @@ import logging
 import os
 import sys
 import threading
-from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar
+from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar

 from jax._src import lib
 from jax._src.lib import jax_jit
@@ -75,6 +75,12 @@ class FlagHolder(Generic[_T]):
     self._flags = flags
     self._name = name

+  def __bool__(self) -> NoReturn:
+    raise TypeError(
+        "bool() not supported for instances of type '{0}' "
+        "(did you mean to use '{0}.value' instead?)".format(
+            type(self).__name__))
+
   @property
   def value(self) -> _T:
     return getattr(self._flags, self._name)
@@ -523,6 +529,12 @@ class _StateContextManager(Generic[_T]):
     self._validate_new_val_hook = validate_new_val_hook
     self._default_value = default_value

+  def __bool__(self) -> NoReturn:
+    raise TypeError(
+        "bool() not supported for instances of type '{0}' "
+        "(did you mean to use '{0}.value' instead?)".format(
+            type(self).__name__))
+
   @property
   def value(self) -> _T:
     val = _thread_local_state.__dict__.get(self._name, unset)
@@ -935,7 +947,7 @@ use_original_compilation_cache_key_generation = config.define_bool_state(
     "deployed, this flag and the original cache-key generation algorithm "
     "will be removed.")

-config.define_enum_state(
+default_dtype_bits = config.define_enum_state(
     name='jax_default_dtype_bits',
     enum_values=['32', '64'],
     default='64',
@@ -1090,7 +1102,7 @@ bcoo_cusparse_lowering = config.define_bool_state(

 # TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
 # if the intended backend can handle lowering the result
-config.define_bool_state(
+dynamic_shapes = config.define_bool_state(
     name='jax_dynamic_shapes',
     default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
     help=('Enables experimental features for staging out computations with '
@@ -1102,19 +1114,19 @@ config.define_bool_state(

 # This flag is temporary during rollout of the remat barrier.
 # TODO(parkers): Remove if there are no complaints.
-config.define_bool_state(
+remat_opt_barrier = config.define_bool_state(
     name='jax_remat_opt_barrier',
     default=(lib.version >= (0, 3, 6)),
     help=('Enables using optimization-barrier op for lowering remat.'))

 # TODO(sharadmv,mattjj): set default to True, then remove
-config.define_bool_state(
+eager_pmap = config.define_bool_state(
     name='jax_eager_pmap',
     default=True,
     upgrade=True,
     help='Enable eager-mode pmap when jax_disable_jit is activated.')

-config.define_bool_state(
+xla_runtime_errors = config.define_bool_state(
     name='jax_experimental_unsafe_xla_runtime_errors',
     default=False,
     help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
@@ -36,9 +36,8 @@ from weakref import ref
 import numpy as np

 from jax._src import dtypes
-from jax._src import config as jax_config
+from jax._src import config
 from jax._src import effects
-from jax._src.config import config
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
     TracerIntegerConversionError, UnexpectedTracerError)
@@ -60,9 +59,9 @@ zip, unsafe_zip = safe_zip, zip
 map, unsafe_map = safe_map, map


-_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer(
+_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.DEFINE_integer(
     'jax_tracer_error_num_traceback_frames',
-    jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
+    config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
     help='Set the number of stack frames in JAX tracer error messages.'
 )

@@ -269,7 +268,7 @@ class JaxprEqn(NamedTuple):
 # TODO(mattjj): call typecheck rules here, so we don't form bad eqns
 def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
   source_info = source_info or source_info_util.new_source_info()
-  if config.jax_enable_checks:
+  if config.enable_checks.value:
     assert all(isinstance(x, (Var, Literal)) for x in invars)
     assert all(isinstance(v, Var) for v in outvars)
   return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
@@ -381,7 +380,7 @@ class Primitive:
     return f'{self.name}'

   def bind(self, *args, **params):
-    assert (not config.jax_enable_checks or
+    assert (not config.enable_checks.value or
             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
     return self.bind_with_trace(find_top_trace(args), args, params)

@@ -438,7 +437,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
     return v.val if isinstance(v, Literal) else env[v]

   def write(v: Var, val: Any) -> None:
-    if config.jax_enable_checks and not config.jax_dynamic_shapes:
+    if config.enable_checks.value and not config.dynamic_shapes.value:
       assert typecheck(v.aval, val), (v.aval, val)
     env[v] = val

@@ -739,7 +738,7 @@ class Tracer(typing.Array):

   def __getattr__(self, name):
     # if the aval property raises an AttributeError, gets caught here
-    assert not config.jax_enable_checks or name != "aval"
+    assert not config.enable_checks.value or name != "aval"

     try:
       attr = getattr(self.aval, name)
@@ -989,7 +988,7 @@ def _update_thread_local_jit_state(dynamic):
   # TODO(mattjj): add a test that verifies that JIT-ted functions are not kept
   # alive by the JIT cache, particularly for nested JIT-ted functions.
   copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
-  jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)
+  config.update_thread_local_jit_state(dynamic_trace_state=copy)


 # The global state of the tracer is accessed by a thread-local object.
@@ -1015,7 +1014,7 @@ def _initialize_jax_jit_thread_local_state():
   if tls.extra_jit_context is None:
     dynamic = thread_local_state.trace_state.trace_stack.dynamic
     copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
-    jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)
+    config.update_thread_local_jit_state(dynamic_trace_state=copy)


 jax_jit.set_thread_local_state_initialization_callback(
@@ -1151,7 +1150,7 @@ def new_main(trace_type: type[Trace], dynamic: bool = False,
     stack.dynamic = prev_dynamic
     _update_thread_local_jit_state(stack.dynamic)

-  if config.jax_check_tracer_leaks:
+  if config.check_tracer_leaks.value:
     t = ref(main)
     del main
     if t() is not None:
@@ -1188,7 +1187,7 @@ def new_base_main(trace_type: type[Trace],
     stack.stack[0] = prev_base
     _update_thread_local_jit_state(stack.dynamic)

-  if config.jax_check_tracer_leaks:
+  if config.check_tracer_leaks.value:
     t = ref(main)
     del main
     if t() is not None:
@@ -1268,7 +1267,7 @@ def new_sublevel() -> Generator[None, None, None]:
   finally:
     thread_local_state.trace_state.substack.pop()

-  if config.jax_check_tracer_leaks:
+  if config.check_tracer_leaks.value:
     t = ref(sublevel)
     del sublevel
     if t() is not None:
@@ -2026,9 +2025,9 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
     return operator.index(dim)
   except TypeError as e:
     type_error = e
-  if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
+  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
     return dim
-  elif (config.jax_dynamic_shapes and isinstance(dim, DArray) and
+  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
         type(dim._aval.dtype) is bint and not dim._aval.shape):
     return dim
   elif is_dim(dim):
@@ -2230,7 +2229,7 @@ class CallPrimitive(Primitive):
     new_params = dict(params)
     jaxpr = new_params.pop('call_jaxpr')
     subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
     return [subfun], new_params

@@ -2458,14 +2457,14 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
   frame = AxisEnvFrame(axis_name, size, tag)
   ts = thread_local_state.trace_state
   ts.axis_env.append(frame)
-  jax_config.update_thread_local_jit_state(
+  config.update_thread_local_jit_state(
       axis_env_state=tuple(f for f in ts.axis_env
                            if f.name is not no_axis_name))
   try:
     yield
   finally:
     ts.axis_env.pop()
-    jax_config.update_thread_local_jit_state(
+    config.update_thread_local_jit_state(
         axis_env_state=tuple(f for f in ts.axis_env
                              if f.name is not no_axis_name))

@@ -2474,14 +2473,14 @@ def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None):
   frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
   ts = thread_local_state.trace_state
   ts.axis_env.extend(frames)
-  jax_config.update_thread_local_jit_state(
+  config.update_thread_local_jit_state(
       axis_env_state=tuple(f for f in ts.axis_env
                            if f.name is not no_axis_name))
   try:
     yield
   finally:
     for _ in frames: ts.axis_env.pop()
-    jax_config.update_thread_local_jit_state(
+    config.update_thread_local_jit_state(
         axis_env_state=tuple(f for f in ts.axis_env
                              if f.name is not no_axis_name))

@@ -2493,12 +2492,12 @@ def stash_axis_env():
   # be raised.
   ts = thread_local_state.trace_state
   prev_axis_env, ts.axis_env = ts.axis_env, []
-  jax_config.update_thread_local_jit_state(axis_env_state=())
+  config.update_thread_local_jit_state(axis_env_state=())
   try:
     yield
   finally:
     ts.axis_env = prev_axis_env
-    jax_config.update_thread_local_jit_state(
+    config.update_thread_local_jit_state(
         axis_env_state=tuple(f for f in ts.axis_env
                              if f.name is not no_axis_name))

@@ -18,6 +18,7 @@ from functools import update_wrapper, reduce, partial
 import inspect
 from typing import Any, Callable, Generic, Optional, TypeVar

+from jax._src import config
 from jax._src import core
 from jax._src import custom_api_util
 from jax._src.custom_transpose import custom_transpose
@@ -28,7 +29,6 @@ from jax._src import traceback_util
 from jax._src.ad_util import (
     stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
 from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
-from jax._src.config import config
 from jax._src.core import raise_to_shaped
 from jax._src.errors import UnexpectedTracerError
 from jax._src.interpreters import ad
@@ -590,7 +590,7 @@ class custom_vjp(Generic[ReturnValue]):
       raise AttributeError(msg)
     fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
     args = _resolve_kwargs(self.fun, args, kwargs)
-    if config.jax_enable_custom_vjp_by_custom_transpose:
+    if config.enable_custom_vjp_by_custom_transpose.value:
       if self.nondiff_argnums:
         raise NotImplementedError(
             'nondiff_argnums not implemented for new custom_vjp')
@@ -1072,7 +1072,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
   """
   flat_args, in_tree = tree_flatten(example_args)
   in_avals = tuple(map(abstractify, flat_args))
-  if config.jax_check_tracer_leaks:
+  if config.check_tracer_leaks.value:
     return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
   else:
     return _closure_convert_for_avals(fun, in_tree, in_avals)
@@ -30,6 +30,7 @@ import numpy as np

 import jax
 from jax._src import basearray
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import linear_util as lu
@@ -39,7 +40,6 @@ from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
 from jax._src import xla_bridge as xb
-from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -160,7 +160,7 @@ def xla_primitive_callable(
       lowering_parameters=mlir.LoweringParameters())
   compiled = computation.compile()
   if xla_extension_version >= 192:
-    if config.jax_disable_jit:
+    if config.disable_jit.value:
       call = compiled.unsafe_call
     else:
       call = compiled.create_cpp_call_for_apply_primitive(out_tree())
@@ -262,7 +262,7 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
   if _on_exit:
     yield
   else:
-    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
+    log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
     start_time = time.time()
     yield
     elapsed_time = time.time() - start_time
@@ -395,7 +395,7 @@ def _initial_style_primitive_replicas(params: dict[str, Any]) -> int:
              default=1)

 def needs_check_special() -> bool:
-  return config.jax_debug_infs or config.jax_debug_nans
+  return config.debug_infs.value or config.debug_nans.value

 def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
   if needs_check_special():
@@ -404,9 +404,9 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:

 def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
   if dtypes.issubdtype(dtype, np.inexact):
-    if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
+    if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
       raise FloatingPointError(f"invalid value (nan) encountered in {name}")
-    if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
+    if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
       raise FloatingPointError(f"invalid value (inf) encountered in {name}")


@@ -19,7 +19,7 @@ import os
 from typing import Any, Optional, Union

 from jax._src import clusters
-from jax._src.config import config
+from jax._src import config
 from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version

@@ -63,8 +63,8 @@ class State:
     if local_device_ids:
       visible_devices = ','.join(str(x) for x in local_device_ids)  # type: ignore[union-attr]
       logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
-      config.update("jax_cuda_visible_devices", visible_devices)
-      config.update("jax_rocm_visible_devices", visible_devices)
+      config.config.update("jax_cuda_visible_devices", visible_devices)
+      config.config.update("jax_rocm_visible_devices", visible_devices)

     self.process_id = process_id

@@ -30,7 +30,7 @@ import warnings
 import ml_dtypes
 import numpy as np

-from jax._src.config import config
+from jax._src import config
 from jax._src.typing import DType, DTypeLike

 from jax._src import traceback_util
@@ -59,7 +59,6 @@ class extended(np.generic):
   >>> jnp.issubdtype(key.dtype, dtypes.extended)
   True
   """
-  pass


 class prng_key(extended):
@@ -75,7 +74,6 @@ class prng_key(extended):
   >>> jnp.issubdtype(key.dtype, dtypes.prng_key)
   True
   """
-  pass


 class ExtendedDType(metaclass=abc.ABCMeta):
@@ -139,12 +137,28 @@ _int4_dtypes = [
 ]

 # Default types.
-bool_: type = np.bool_
-int_: type = np.int32 if config.jax_default_dtype_bits == '32' else np.int64
-uint: type = np.uint32 if config.jax_default_dtype_bits == '32' else np.uint64
-float_: type = np.float32 if config.jax_default_dtype_bits == '32' else np.float64
-complex_: type = np.complex64 if config.jax_default_dtype_bits == '32' else np.complex128
-_default_types: dict[str, type] = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
+bool_ = np.bool_
+int_: type[Any]
+uint: type[Any]
+float_: type[Any]
+complex_: type[Any]
+if config.default_dtype_bits.value == '32':
+  int_ = np.int32
+  uint = np.uint32
+  float_ = np.float32
+  complex_ = np.complex64
+else:
+  int_ = np.int64
+  uint = np.uint64
+  float_ = np.float64
+  complex_ = np.complex128
+_default_types: dict[str, type[Any]] = {
+    'b': bool_,
+    'i': int_,
+    'u': uint,
+    'f': float_,
+    'c': complex_,
+}

 # Trivial vectorspace datatype needed for tangent values of int/bool primals
 float0: np.dtype = np.dtype([('float0', np.void, 0)])
@@ -219,7 +233,7 @@ def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False, allow_opa
         "allow_opaque_dtype argument is deprecated; use allow_extended_dtype.",
         DeprecationWarning)
     allow_extended_dtype = allow_opaque_dtype
-  return _canonicalize_dtype(config.x64_enabled, allow_extended_dtype, dtype)  # type: ignore[bad-return-type]
+  return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype)  # type: ignore[bad-return-type]

 # Default dtypes corresponding to Python scalars.
 python_scalar_dtypes : dict[type, DType] = {
@@ -507,7 +521,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
   if len(LUB) == 1:
     return LUB.pop()
   elif len(LUB) == 0:
-    if config.jax_numpy_dtype_promotion == 'strict':
+    if config.numpy_dtype_promotion.value == 'strict':
       msg = (
           f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
          "promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting "
@@ -553,7 +567,7 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
   # object identity, not object equality, due to the behavior of np.dtype.__eq__
   a_tp = cast(JAXType, a if any(a is t for t in _weak_types) else np.dtype(a))
   b_tp = cast(JAXType, b if any(b is t for t in _weak_types) else np.dtype(b))
-  return np.dtype(_least_upper_bound(config.jax_numpy_dtype_promotion, a_tp, b_tp))
+  return np.dtype(_least_upper_bound(config.numpy_dtype_promotion.value, a_tp, b_tp))

 def is_weakly_typed(x: Any) -> bool:
   try:
@@ -604,17 +618,17 @@ def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
     # Trivial promotion case. This allows extended dtypes through.
     out_dtype = dtypes[0]
     out_weak_type = False
-  elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
+  elif all(weak_types) and config.numpy_dtype_promotion.value != 'strict':
     # If all inputs are weakly typed, we compute the bound of the strongly-typed
     # counterparts and apply the weak type at the end. This avoids returning the
     # incorrect result with non-canonical weak types (e.g. weak int16).
     # TODO(jakevdp): explore removing this special case.
-    result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
+    result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
                                      *{_jax_type(dtype, False) for dtype in dtypes})
     out_dtype = dtype(result_type)
     out_weak_type = True
   else:
-    result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
+    result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
                                      *{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
     out_dtype = dtype(result_type)
     out_weak_type = any(result_type is t for t in _weak_types)
@@ -29,6 +29,7 @@ from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
 import warnings

 from jax._src import ad_util
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import effects as effects_lib
@@ -38,7 +39,6 @@ from jax._src import sharding_impls
 from jax._src import source_info_util
 from jax._src import util
 from jax._src import xla_bridge as xb
-from jax._src.config import config
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
@@ -316,9 +316,8 @@ register_constant_handler(core.Token, _token_constant_handler)

 def get_canonical_source_file(frame: source_info_util.Frame) -> str:
   source_file = frame.file_name
-  if config.jax_hlo_source_file_canonicalization_regex:
-    source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
-                         '', source_file)
+  if pattern := config.hlo_source_file_canonicalization_regex.value:
+    source_file = re.sub(pattern, '', source_file)
   return source_file

 def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
@@ -349,7 +348,7 @@ def _source_info_to_location(
     name_stack: source_info_util.NameStack) -> ir.Location:
   eqn_str = (f'{str(source_info.name_stack)}/'
              f'{core.str_eqn_compact(primitive.name, params)}')
-  if config.jax_include_full_tracebacks_in_locations:
+  if config.include_full_tracebacks_in_locations.value:
     if source_info.traceback is None:
       loc = ir.Location.unknown()
     else:
@@ -622,7 +621,7 @@ def sharded_aval(aval: core.AbstractValue,

 def eval_dynamic_shape(ctx: LoweringRuleContext,
                        shape: core.Shape) -> tuple[int | Value, ...]:
-  if config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     return tuple(ctx.axis_size_env.get(d, d) for d in shape)  # type: ignore
   else:
     ctx = ctx.replace(
@@ -774,7 +773,7 @@ def lower_jaxpr_to_module(
   host_callbacks: list[Any] = []

   dim_vars: Sequence[str]
-  if not config.jax_dynamic_shapes:
+  if not config.dynamic_shapes.value:
     # Find the dimension variables
     all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
                     for d in aval.shape if not core.is_constant_dim(d)]
@@ -1418,7 +1417,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
         module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
         avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
         tokens_out=None, dim_var_values=dim_var_values)
-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       axis_size_env = {d: read(d)[0]
                        for a in avals_in if type(a) is core.DShapedArray
                        for d in a.shape if type(d) is core.Var}
@@ -1589,7 +1588,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
     f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
     wrapped_fun = lu.wrap_init(f, params)

-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       # We might be applying this function to arguments with dynamic shapes,
       # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
       # case, we need to form a jaxpr with leading binders for those axis size
@@ -26,13 +26,13 @@ from weakref import ref

 import numpy as np

-from jax._src import linear_util as lu
-from jax._src.config import config
 from jax._src import ad_util
 from jax._src import api_util
+from jax._src import config
 from jax._src import core
-from jax._src import effects
 from jax._src import dtypes
+from jax._src import effects
+from jax._src import linear_util as lu
 from jax._src import profiler
 from jax._src import source_info_util
 from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
@@ -106,7 +106,7 @@ class PartialVal(tuple):
   """
   def __new__(cls, xs: tuple[AbstractValue | None, core.Value]):
     pv, const = xs
-    if config.jax_enable_checks:
+    if config.enable_checks.value:
       # type checks
       assert isinstance(pv, (AbstractValue, type(None))), xs
       assert (const is None or isinstance(const, core.Tracer) or
@@ -255,7 +255,7 @@ class JaxprTrace(Trace['JaxprTracer']):
     # which were unknown to the first call (corresponding to in_avals).

     # Wrap f to perform the partial evaluation and plumb out aux data.
-    if not config.jax_dynamic_shapes:
+    if not config.dynamic_shapes.value:
       f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False)
       f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns),
                                              tuple(in_avals))
@@ -275,7 +275,7 @@ class JaxprTrace(Trace['JaxprTracer']):
     out_consts, non_fwd_res_ = split_list(out, [sum(out_knowns)])

     # Form the complete list of residuals by forwarding some inputs.
-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       # With dynamic shapes, we may need to forward implicit arguments.
       in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns)
       in_consts_full = [None] * len(f.in_type)
@@ -303,7 +303,7 @@ class JaxprTrace(Trace['JaxprTracer']):
     staged_params = update_params(staged_params, map(op.not_, in_knowns),
                                   num_new_args)
     # The outputs of the staged-out call are Tracers with the new eqn as recipe.
-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       # With dynamic shapes, we may need to substitute Tracers into avals.
       out_tracers = []
       for aval, _ in out_type:
@@ -958,21 +958,21 @@ def tracers_to_jaxpr(
   jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
   jaxpr = Jaxpr(const_vars, invars,  # type: ignore[list-item,arg-type]
                 outvars, eqns, jaxpr_effects)
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
+  config.enable_checks.value and core.check_jaxpr(jaxpr)
   # del getvar  # needed to avoid cyclic-reference closure, apparently!
   return jaxpr, const_vals, env_vals

 @weakref_lru_cache
 def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
   """Moves the constvars to the start of invars."""
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
+  config.enable_checks.value and core.check_jaxpr(jaxpr)
   dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
       arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
   lifted_jaxpr = Jaxpr(constvars=(),
                        invars=jaxpr.constvars + jaxpr.invars,
                        outvars=jaxpr.outvars, eqns=jaxpr.eqns,
                        effects=jaxpr.effects, debug_info=dbg)
-  config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
+  config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
   return lifted_jaxpr

 @weakref_lru_cache
@@ -980,24 +980,24 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
   """Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
   if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
     raise NotImplementedError
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
+  config.enable_checks.value and core.check_jaxpr(jaxpr)
   constvars, invars = split_list(jaxpr.invars, [n])
   dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
       arg_names=jaxpr.debug_info.arg_names[n:])
   lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
                                debug_info=dbg)
-  config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
+  config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
   return lifted_jaxpr

 def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
   if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
     raise NotImplementedError
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
+  config.enable_checks.value and core.check_jaxpr(jaxpr)
   env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
   converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
                           invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
                           effects=jaxpr.effects)
-  config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
+  config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
   return converted_jaxpr


@@ -1090,7 +1090,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):

   # check jaxpr_known and jaxpr_unknown in isolation
   # TODO(mattjj): enable weak type checking here
-  if config.jax_enable_checks:
+  if config.enable_checks.value:
     core.check_jaxpr(jaxpr_known)
     core.check_jaxpr(jaxpr_unknown)
   # check jaxpr_known has input type corresponding to known inputs of jaxpr
@@ -1256,7 +1256,7 @@ def _partial_eval_jaxpr_custom_cached(
                      known_outvars, known_eqns)
   jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
|
||||||
known_eqns, known_effects)
|
known_eqns, known_effects)
|
||||||
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
|
config.enable_checks.value and core.check_jaxpr(jaxpr_known)
|
||||||
|
|
||||||
_, ins_staged = partition_list(in_inst, jaxpr.invars)
|
_, ins_staged = partition_list(in_inst, jaxpr.invars)
|
||||||
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
|
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
|
||||||
@ -1265,7 +1265,7 @@ def _partial_eval_jaxpr_custom_cached(
|
|||||||
outs_staged, staged_eqns)
|
outs_staged, staged_eqns)
|
||||||
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
|
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
|
||||||
outs_staged, staged_eqns, staged_effects)
|
outs_staged, staged_eqns, staged_effects)
|
||||||
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
|
config.enable_checks.value and core.check_jaxpr(jaxpr_staged)
|
||||||
|
|
||||||
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
|
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
|
||||||
len(non_input_res_refs))
|
len(non_input_res_refs))
|
||||||
@ -1486,7 +1486,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
|
|||||||
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
|
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
|
||||||
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
|
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
|
||||||
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
|
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
|
||||||
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
|
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
|
||||||
|
|
||||||
return new_jaxpr, used_inputs
|
return new_jaxpr, used_inputs
|
||||||
|
|
||||||
@ -1702,7 +1702,7 @@ class JaxprStackFrame:
|
|||||||
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
|
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
|
||||||
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
||||||
jaxpr, out_type = _add_implicit_outputs(jaxpr)
|
jaxpr, out_type = _add_implicit_outputs(jaxpr)
|
||||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||||
return jaxpr, out_type, constvals
|
return jaxpr, out_type, constvals
|
||||||
|
|
||||||
def newvar(self, aval):
|
def newvar(self, aval):
|
||||||
@ -2226,7 +2226,7 @@ def trace_to_subjaxpr_dynamic(
|
|||||||
out_tracers = map(trace.full_raise, ans)
|
out_tracers = map(trace.full_raise, ans)
|
||||||
jaxpr, consts = frame.to_jaxpr(out_tracers)
|
jaxpr, consts = frame.to_jaxpr(out_tracers)
|
||||||
del fun, main, trace, frame, in_tracers, out_tracers, ans
|
del fun, main, trace, frame, in_tracers, out_tracers, ans
|
||||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||||
return jaxpr, [v.aval for v in jaxpr.outvars], consts
|
return jaxpr, [v.aval for v in jaxpr.outvars], consts
|
||||||
|
|
||||||
|
|
||||||
@ -2432,7 +2432,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
|
|||||||
|
|
||||||
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
|
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
|
||||||
jaxpr.effects, jaxpr.debug_info)
|
jaxpr.effects, jaxpr.debug_info)
|
||||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||||
return new_jaxpr, out_type
|
return new_jaxpr, out_type
|
||||||
|
|
||||||
|
|
||||||
|
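[Note: the hunks above all apply one mechanical rewrite, so a brief orientation may help. Under the old scheme, `config.jax_enable_checks` was resolved dynamically through `__getattr__` on the global config object, which type checkers and IDEs cannot see through; under the new scheme, `config.enable_checks` is an ordinary module-level attribute holding a state object, and `.value` reads its current setting. A minimal sketch of the two access styles, for illustration only:

from jax._src import config

# Old style (dynamic attribute lookup, invisible to static analysis):
#   from jax._src.config import config
#   if config.jax_enable_checks: ...

# New style (statically-typed state object; .value reads the current,
# possibly context-overridden, setting):
if config.enable_checks.value:
  print("jaxpr checking is enabled")
]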
@@ -34,8 +34,9 @@ import jax
 from jax.errors import JAXTypeError

 from jax._src import api_util
-from jax._src import core
 from jax._src import compiler
+from jax._src import config
+from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import effects
@@ -51,7 +52,6 @@ from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.abstract_arrays import array_types
-from jax._src.config import config
 from jax._src.core import DShapedArray
 from jax._src.core import ShapedArray
 from jax._src.interpreters import ad
@@ -279,7 +279,7 @@ def xla_pmap_impl_lazy(
     donated_invars: Sequence[bool],
     is_explicit_global_axis_size: bool,
 ) -> Callable:
-  if (config.jax_disable_jit and config.jax_eager_pmap and
+  if (config.disable_jit.value and config.eager_pmap.value and
       not is_explicit_global_axis_size and not any(d for d in donated_invars)):
     def _emap_apply_fn(*args):
       return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
@@ -296,7 +296,7 @@ def xla_pmap_impl_lazy(
                             is_explicit_global_axis_size, *abstract_args)

   # Don't re-abstractify args unless logging is enabled for performance.
-  if config.jax_distributed_debug:
+  if config.distributed_debug.value:
     distributed_debug_log(("Running pmapped function", name),
                           ("python function", fun.f),
                           ("devices", devices),
@@ -433,7 +433,7 @@ class MapTrace(core.Trace):
   def process_map(self, map_primitive, fun, tracers, params):
     if params['devices'] is not None:
       raise ValueError("Nested pmap with explicit devices argument.")
-    if not config.jax_disable_jit:
+    if not config.disable_jit.value:
       bind = HashableFunction(
           lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
           (map_primitive, fun))
@@ -728,7 +728,7 @@ def lower_parallel_callable(
         f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
         f"{replicas.jaxpr_replicas}")

-  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
+  log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
   if logger.isEnabledFor(log_priority):
     logger.log(log_priority,
                "Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
@@ -1630,7 +1630,7 @@ ShardingInfo = tuple[


 def _get_default_device() -> xc.Device:
-  return config.jax_default_device or xb.local_devices()[0]
+  return config.default_device.value or xb.local_devices()[0]


 class _thread_local_decorator(threading.local):
@@ -1798,7 +1798,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   # code without tuple conversion.
   device_assignment = tuple(da_object)

-  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
+  log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
   if logger.isEnabledFor(log_priority):
     logger.log(log_priority,
                "Compiling %s for with global shapes and types %s. "
@@ -2029,7 +2029,7 @@ def lower_sharding_computation(
                                      transfer_mem_kind_in_jaxpr))

   if not da_object.is_fully_addressable:  # type: ignore
-    if inline and config.jax_spmd_mode != 'allow_all':
+    if inline and config.spmd_mode.value != 'allow_all':
       raise RuntimeError(
           "Running operations on `Array`s that are not fully addressable by this "
           "process (i.e. `Array`s with data sharded across multiple devices and "
@@ -2116,7 +2116,7 @@ def lower_mesh_computation(

   global_axis_sizes = mesh.shape

-  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
+  log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
   if logger.isEnabledFor(log_priority):
     logger.log(log_priority,
                "Compiling %s for %s mesh with global shapes and types %s. "
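[Note: reading `config.log_compiles.value` at each lowering site, as above, recomputes the log priority per compilation, so the flag can be flipped at runtime. A usage sketch (the function `f` is made up for illustration; `jax_log_compiles` is the flag's public name):

import logging
import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.WARNING)
jax.config.update("jax_log_compiles", True)  # config.log_compiles.value -> True

@jax.jit
def f(x):  # hypothetical example function
  return x * 2.0

f(jnp.arange(4.0))  # the "Compiling ..." message is now emitted at WARNING
]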
@@ -36,6 +36,7 @@ from jax._src import ad_util
 from jax._src import api
 from jax._src import api_util
 from jax._src import array
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
@@ -45,7 +46,6 @@ from jax._src import pretty_printer as pp
 from jax._src import source_info_util
 from jax._src import util
 from jax._src.abstract_arrays import array_types
-from jax._src.config import config
 from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
                            raise_to_shaped, abstract_token, canonicalize_shape)
 from jax._src.interpreters import ad
@@ -93,7 +93,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
       raise TypeError(msg)

   assert shapes
-  if config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     # pass dynamic shapes through unchecked
     return
   else:
@@ -182,7 +182,7 @@ def _extract_tracers_dyn_shape(
     shape: Sequence[Union[int, core.Tracer]]
 ) -> tuple[list[core.Tracer], list[Optional[int]]]:
   # Given a sequence representing a shape, pull out Tracers, replacing with None
-  if config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     # We must gate this behavior under a flag because otherwise the errors
     # raised are different (and have worse source provenance information).
     dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
@@ -794,7 +794,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   """
   if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array):
     return type_cast(Array, operand)
-  if config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     # We must gate this behavior under a flag because otherwise the errors
     # raised are different (and have worse source provenance information).
     dyn_shape, static_shape = _extract_tracers_dyn_shape(shape)
@@ -1649,7 +1649,7 @@ def _unbroadcast(aval, x):
     return _reduce_sum(x, list(range(len(x_shape))))
   else:
     dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)]
-    if config.jax_enable_checks: assert all(aval.shape[i] == 1 for i in dims)
+    if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims)
     return reshape(_reduce_sum(x, dims), aval.shape)

 def _maybe_broadcast(target_shape, x):
@@ -3358,7 +3358,7 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions):
     msg = 'reshape new_sizes must all be positive, got {}.'
     raise TypeError(msg.format(new_sizes))
   # TODO(necula): re-enable this check
-  if (not config.jax_dynamic_shapes and
+  if (not config.dynamic_shapes.value and
       not math.prod(np.shape(operand)) == math.prod(new_sizes)):
     msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
     raise TypeError(msg.format(new_sizes, np.shape(operand)))
@@ -4850,7 +4850,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
     # bool(obj) for an ndarray raises an error, so we check len
     if not len(obj):  # pylint: disable=g-explicit-length-test
       return
-  if (config.jax_dynamic_shapes and isinstance(obj, (tuple, list)) and
+  if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and
      any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)):
    return  # TODO(mattjj): handle more checks in the dynamic shape case
   obj_arr = np.array(obj)
@@ -4913,17 +4913,17 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[Precision
   value to apply to both operands, or as a sequence of two values.
   """
   if precision is None:
-    if config.jax_default_matmul_precision is None:
+    if config.default_matmul_precision.value is None:
       return None
     try:
       return type_cast(
           tuple[PrecisionType, PrecisionType],
-          (Precision(config.jax_default_matmul_precision),
-           Precision(config.jax_default_matmul_precision)))
+          (Precision(config.default_matmul_precision.value),
+           Precision(config.default_matmul_precision.value)))
     except TypeError:
       raise ValueError(
          "jax_default_matmul_precision flag must be set to None or a value in "
-         f"{list(Precision._strings)}, but got {config.jax_default_matmul_precision}"
+         f"{list(Precision._strings)}, but got {config.default_matmul_precision.value}"
      ) from None
   elif isinstance(precision, str) and precision in Precision._strings:
     return type_cast(tuple[PrecisionType, PrecisionType],
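[Note: `canonicalize_precision` above falls back to `config.default_matmul_precision.value` when a call site passes `precision=None`. A sketch of how that default is typically set from user code, assuming the public flag name `jax_default_matmul_precision` and the `jax.default_matmul_precision` context manager:

import jax
import jax.numpy as jnp

# Process-wide default: every dot/conv with precision=None uses float32.
jax.config.update("jax_default_matmul_precision", "float32")

# Scoped override: within the block, canonicalize_precision sees "bfloat16".
with jax.default_matmul_precision("bfloat16"):
  y = jnp.dot(jnp.ones((8, 8)), jnp.ones((8, 8)))
]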
@@ -68,12 +68,13 @@ import operator
 from typing import Any, Callable, NamedTuple
 import weakref

-from jax._src.tree_util import tree_map
-from jax._src.config import config
+from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
+from jax._src.tree_util import tree_map
 from jax._src.util import curry


 traceback_util.register_exclusion(__file__)


@@ -333,13 +334,13 @@ def cache(call: Callable):

   def memoized_fun(fun: WrappedFun, *args):
     cache = fun_caches.setdefault(fun.f, {})
-    if config.jax_check_tracer_leaks:
+    if config.check_tracer_leaks.value:
       key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
-             config.x64_enabled, config.jax_default_device,
-             config._trace_context())
+             config.enable_x64.value, config.default_device.value,
+             config.config._trace_context())
     else:
-      key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
-             config.jax_default_device, config._trace_context())
+      key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
+             config.default_device.value, config.config._trace_context())
     result = cache.get(key, None)
     if result is not None:
       ans, stores = result
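[Note: `memoized_fun`'s key includes `config.enable_x64.value`, `config.default_device.value`, and the trace context because a function traced under one configuration must not be served from a cache entry built under another. A hedged sketch of the observable effect (`f` is illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # illustrative function
  return x + 1

f(jnp.ones(3))                             # traced with enable_x64 == False
jax.config.update("jax_enable_x64", True)
f(jnp.ones(3))                             # retraced: the config state in the
                                           # cache key above has changed
jax.config.update("jax_enable_x64", False)
]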
@@ -20,10 +20,10 @@ from typing import Any, Callable, NamedTuple, Optional, TypeVar

 import warnings

-from jax._src import dtypes
 from jax._src import api
+from jax._src import config
 from jax._src import core
-from jax._src.config import config
+from jax._src import dtypes
 from jax._src.lax import lax
 from jax._src.util import safe_zip, safe_map
 from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
@@ -161,7 +161,7 @@ def _wraps(
     try:
       mod = module or fun.__module__
     except AttributeError:
-      if config.jax_enable_checks:
+      if config.enable_checks.value:
         raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.")
     else:
       name = f"{mod}.{name}"
@@ -206,7 +206,7 @@ def _wraps(
       if kept_sections:
         docstr += "\n" + "\n\n".join(kept_sections) + "\n"
     except:
-      if config.jax_enable_checks:
+      if config.enable_checks.value:
         raise
       docstr = fun.__doc__

@@ -229,7 +229,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
     return [lax.asarray(arg) for arg in args]
   else:
     shapes = [np.shape(arg) for arg in args]
-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       # With dynamic shapes we don't support singleton-dimension broadcasting;
       # we instead broadcast out to the full shape as a temporary workaround.
       # TODO(mattjj): revise this workaround
@@ -242,7 +242,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
       if len(nonscalar_ranks) < 2:
         return [lax.asarray(arg) for arg in args]  # rely on lax scalar promotion
       else:
-        if config.jax_numpy_rank_promotion != "allow":
+        if config.numpy_rank_promotion.value != "allow":
           _rank_promotion_warning_or_error(fun_name, shapes)
         result_rank = len(lax.broadcast_shapes(*shapes))
         return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
@@ -250,13 +250,13 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:


 def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
-  if config.jax_numpy_rank_promotion == "warn":
+  if config.numpy_rank_promotion.value == "warn":
     msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
            "Set the jax_numpy_rank_promotion config option to 'allow' to "
            "disable this warning; for more information, see "
            "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
     warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
-  elif config.jax_numpy_rank_promotion == "raise":
+  elif config.numpy_rank_promotion.value == "raise":
     msg = ("Operands could not be broadcast together for {} on shapes {} "
            "and with the config option jax_numpy_rank_promotion='raise'. "
            "For more information, see "
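[Note: the `numpy_rank_promotion` modes gated above can be exercised directly; a short sketch using the `jax.numpy_rank_promotion` context manager described at the URL in the message:

import jax
import jax.numpy as jnp

with jax.numpy_rank_promotion("raise"):
  try:
    jnp.ones((3, 3)) + jnp.ones(3)  # implicit rank promotion is refused
  except ValueError as e:
    print(e)  # "Operands could not be broadcast together ..."
]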
@@ -25,6 +25,7 @@ import warnings

 import numpy as np

+from jax._src import config
 from jax._src import core
 from jax._src import stages
 from jax._src import dispatch
@@ -46,7 +47,6 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.partition_spec import PartitionSpec
 from jax._src.interpreters import xla

-from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -182,7 +182,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
   @wraps(fun)
   @api_boundary
   def wrapped(*args, **kwargs):
-    if config.jax_disable_jit:
+    if config.disable_jit.value:
       return fun(*args, **kwargs)
     return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]

@@ -275,7 +275,7 @@ def pre_infer_params(fun, in_shardings, out_shardings,
                      donate_argnums, donate_argnames,
                      static_argnums, static_argnames, device,
                      backend, abstracted_axes):
-  if abstracted_axes and not config.jax_dynamic_shapes:
+  if abstracted_axes and not config.dynamic_shapes.value:
     raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")

   check_callable(fun)
@@ -432,7 +432,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
     dyn_kwargs = {}
   del kwargs

-  if (donate_argnums or donate_argnames) and not config.jax_debug_nans:
+  if (donate_argnums or donate_argnames) and not config.debug_nans.value:
     donated_invars = donation_vector(
         donate_argnums, donate_argnames, dyn_args, dyn_kwargs)
   else:
@@ -462,7 +462,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   assert in_shardings is not None or all(i is not None for i in in_shardings)
   assert out_shardings is not None or all(o is not None for o in out_shardings)

-  if config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
     in_avals = tuple(a for a, e in in_type if e)
   else:
@@ -490,7 +490,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):

   assert len(explicit_args) == len(canonicalized_in_shardings_flat)

-  if config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     implicit_args = _extract_implicit_args(in_type, explicit_args)
   else:
     implicit_args = []
@@ -892,7 +892,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
       "pjit in_shardings", in_tree, orig_in_shardings,
       tupled_args=True)

-  if not config.jax_dynamic_shapes:
+  if not config.dynamic_shapes.value:
     pjit_check_aval_sharding(in_shardings_flat, in_avals,
                              None if debug_info is None else debug_info.arg_names,
                              "pjit arguments", allow_uneven_sharding=False)
@@ -909,14 +909,14 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
       "Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec",
       fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
     pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
           lu.annotate(fun, in_type), debug_info=pe_debug)
     else:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
           fun, in_type, debug_info=pe_debug)

-  if not config.jax_dynamic_shapes:
+  if not config.dynamic_shapes.value:
     jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())

   if any(isinstance(c, core.Tracer) for c in consts):
@@ -944,7 +944,7 @@ def _check_and_canonicalize_out_shardings(
       "pjit out_shardings", out_tree(), orig_out_shardings,
       tupled_args=False)

-  if not config.jax_dynamic_shapes:
+  if not config.dynamic_shapes.value:
     pjit_check_aval_sharding(
         out_shardings_flat, out_type,
         None if debug_info is None else debug_info.result_paths,
@@ -1132,10 +1132,10 @@ def _pjit_call_impl_python(
       lowering_parameters=mlir.LoweringParameters()).compile()
   _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
   # This check is expensive so only do it if enable_checks is on.
-  if compiled._auto_spmd_lowering and config.jax_enable_checks:
+  if compiled._auto_spmd_lowering and config.enable_checks.value:
     pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
                                                jaxpr.jaxpr.debug_info)
-  if config.jax_distributed_debug:
+  if config.distributed_debug.value:
     # Defensively only perform fingerprint logic if debug logging is enabled
     # NOTE(skyewm): I didn't benchmark this
     fingerprint = None
@@ -1151,7 +1151,7 @@ def _pjit_call_impl_python(
   try:
     return compiled.unsafe_call(*args), compiled
   except FloatingPointError:
-    assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
+    assert config.debug_nans.value or config.debug_infs.value  # compiled_fun can only raise in this case

     _ = core.jaxpr_as_fun(jaxpr)(*args)  # may raise, not return

@@ -1159,7 +1159,7 @@ def _pjit_call_impl_python(
     # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
     msg = ("An invalid value was encountered in the output of the "
            f"`jit`-decorated function {name}. Because "
-           "config.jax_debug_nans and/or config.jax_debug_infs is set, the "
+           "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
            "de-optimized function (i.e., the function as if the `jit` "
            "decorator were removed) was called in an attempt to get a more "
            "precise error message. However, the de-optimized function did not "
@@ -1313,7 +1313,7 @@ def pjit_staging_rule(trace, *args, **params):
     jaxpr = params['jaxpr']
     return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
                            propagate_source_info=False)
-  elif config.jax_dynamic_shapes:
+  elif config.dynamic_shapes.value:
     source_info = source_info_util.current()
     out_tracers = []
     for aval in _out_type(params['jaxpr']):
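[Note: the `config.disable_jit.value` check in `_python_pjit`'s `wrapped` above is what makes `jax.disable_jit()` work for jit-decorated functions: when the flag is set, the original Python callable runs directly. A small sketch (`f` is illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # illustrative function
  return x * x

with jax.disable_jit():  # sets config.disable_jit.value within the block
  f(jnp.arange(3.0))     # wrapped() takes the `return fun(*args, **kwargs)` path
]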
@@ -31,7 +31,7 @@ from jax import tree_util
 from jax._src import ad_util
 from jax._src import api
 from jax._src import basearray
-from jax._src import config as config_lib
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
@@ -40,7 +40,6 @@ from jax._src import sharding_specs
 from jax._src import tree_util as tree_util_internal
 from jax._src import typing
 from jax._src.api import jit, vmap
-from jax._src.config import config
 from jax._src.dtypes import float0
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -969,7 +968,7 @@ def _threefry_seed(seed: typing.Array) -> typing.Array:
   convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
   k1 = convert(
       lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
-  with config_lib.numpy_dtype_promotion('standard'):
+  with config.numpy_dtype_promotion('standard'):
     # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
     # inputs. We should avoid this.
     k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
@@ -1270,7 +1269,7 @@ def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:

 @partial(jit, static_argnums=(1,))
 def _threefry_split(key, shape) -> typing.Array:
-  if config.jax_threefry_partitionable:
+  if config.threefry_partitionable.value:
     return _threefry_split_foldlike(key, shape)  # type: ignore
   else:
     return _threefry_split_original(key, shape)  # type: ignore
@@ -1305,7 +1304,7 @@ def threefry_random_bits(key: typing.Array, bit_width, shape):
   if bit_width not in (8, 16, 32, 64):
     raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")

-  if config.jax_threefry_partitionable:
+  if config.threefry_partitionable.value:
     return _threefry_random_bits_partitionable(key, bit_width, shape)
   else:
     return _threefry_random_bits_original(key, bit_width, shape)
@@ -1398,7 +1397,7 @@ def _rbg_seed(seed: typing.Array) -> typing.Array:
   return jnp.concatenate([halfkey, halfkey])

 def _rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
-  if config.jax_threefry_partitionable:
+  if config.threefry_partitionable.value:
     _threefry_split = _threefry_split_foldlike
   else:
     _threefry_split = _threefry_split_original
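[Note: the `threefry_partitionable` reads above select between two key-splitting implementations at call time. Usage sketch, assuming the public flag name `jax_threefry_partitionable`:

import jax

jax.config.update("jax_threefry_partitionable", True)
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)  # dispatches to _threefry_split_foldlike
]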
@@ -16,8 +16,8 @@ from functools import partial
 import operator

 from jax._src import api
+from jax._src import config
 from jax._src import dtypes as _dtypes
-from jax._src.config import config
 from jax._src.tree_util import tree_map, tree_reduce

 import numpy as np
@@ -130,7 +130,7 @@ def check_close(xs, ys, atol=None, rtol=None, err_msg=''):

 def _check_dtypes_match(xs, ys):
   def _assert_dtypes_match(x, y):
-    if config.x64_enabled:
+    if config.enable_x64.value:
       assert _dtype(x) == _dtype(y)
     else:
       assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
@@ -26,13 +26,12 @@ import jax.numpy as jnp
 from jax import lax
 from jax.numpy.linalg import cholesky, svd, eigh

-from jax._src import config as config_lib
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import prng
 from jax._src import xla_bridge
 from jax._src.api import jit, vmap
-from jax._src.config import config
 from jax._src.core import NamedShape
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -77,17 +76,17 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
   elif _arraylike(key):
     # Call random_wrap here to surface errors for invalid keys.
     wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
-    if config.jax_legacy_prng_key == 'error':
+    if config.legacy_prng_key.value == 'error':
       raise ValueError(
         'Legacy uint32 key array passed as key to jax.random function. '
         'Please create keys using jax.random.key(). If use of a raw key array '
         'was intended, set jax_legacy_prng_key="allow".')
-    elif config.jax_legacy_prng_key == 'warn':
+    elif config.legacy_prng_key.value == 'warn':
       warnings.warn(
         'Legacy uint32 key array passed as key to jax.random function. '
         'Please create keys using jax.random.key(). If use of a raw key array '
         'was intended, set jax_legacy_prng_key="allow".', stacklevel=2)
-    elif config.jax_enable_custom_prng:
+    elif config.enable_custom_prng.value:
       # TODO(jakevdp): possibly remove this warning condition.
       warnings.warn(
         'Raw arrays as random keys to jax.random functions are deprecated. '
@@ -101,7 +100,7 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
 def _return_prng_keys(was_wrapped, key):
   # TODO(frostig): remove once we always enable_custom_prng
   assert jnp.issubdtype(key.dtype, dtypes.prng_key)
-  if config.jax_enable_custom_prng:
+  if config.enable_custom_prng.value:
     return key
   else:
     return prng.random_unwrap(key) if was_wrapped else key
@@ -120,7 +119,7 @@ def default_prng_impl():
   The default implementation is determined by ``config.jax_default_prng_impl``,
   which specifies it by name.
   """
-  impl_name = config.jax_default_prng_impl
+  impl_name = config.default_prng_impl.value
   assert impl_name in prng.prngs, impl_name
   return prng.prngs[impl_name]

@@ -225,8 +224,8 @@ def PRNGKey(seed: Union[int, Array], *,
 # TODO(frostig): remove once we always enable_custom_prng
 def _check_default_impl_with_no_custom_prng(impl, name):
   default_impl = default_prng_impl()
-  default_name = config.jax_default_prng_impl
-  if not config.jax_enable_custom_prng and default_impl is not impl:
+  default_name = config.default_prng_impl.value
+  if not config.enable_custom_prng.value and default_impl is not impl:
     raise RuntimeError('jax_enable_custom_prng must be enabled in order '
                        f'to seed an RNG with an implementation "{name}" '
                        f'differing from the default "{default_name}".')
@@ -829,7 +828,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
   else: # 'cholesky'
     factor = cholesky(cov)
   normal_samples = normal(key, shape + mean.shape[-1:], dtype)
-  with config_lib.numpy_rank_promotion('allow'):
+  with config.numpy_rank_promotion('allow'):
     result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
   return result

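[Note: the `legacy_prng_key` gate above distinguishes raw uint32 key arrays from new-style typed keys. A short sketch of the behaviors, using the public names `jax.random.PRNGKey` and `jax.random.key`:

import jax

jax.config.update("jax_legacy_prng_key", "warn")
raw_key = jax.random.PRNGKey(0)  # legacy uint32 key array
jax.random.uniform(raw_key)      # emits the warning constructed above

typed_key = jax.random.key(0)    # typed key: no legacy-path warning
jax.random.uniform(typed_key)
]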
@@ -24,11 +24,11 @@ import numpy as np

 from jax._src import api_util
 from jax._src import ad_util
+from jax._src import config
 from jax._src import core
 from jax._src import linear_util as lu
 from jax._src import source_info_util
 from jax._src import tree_util
-from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
@@ -262,7 +262,7 @@ run_state_p.multiple_results = True

 def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
                     which_linear: tuple[bool, ...]):
-  if config.jax_enable_checks:
+  if config.enable_checks.value:
     core.check_jaxpr(jaxpr)
     assert len(jaxpr.invars) == len(args)
     assert len(which_linear) == len(args)
@@ -27,7 +27,7 @@ from typing import Any, Callable

 import jax
 from jax import core
-from jax._src.config import config
+from jax._src import config
 from jax._src.lib import tpu_mosaic
 from jax._src.lib import xla_client
 from jax.interpreters import mlir
@@ -38,7 +38,7 @@ from jaxlib.mlir.dialects import stablehlo
 from jaxlib.mlir.passmanager import PassManager
 import numpy as np

-config.define_bool_state(
+mosaic_use_cpp_passes = config.config.define_bool_state(
     name="mosaic_use_cpp_passes",
     default=False,
     help=(
@@ -54,13 +54,13 @@ tpu = tpu_mosaic.tpu
 apply_vector_layout = tpu_mosaic.apply_vector_layout
 infer_memref_layout = tpu_mosaic.infer_memref_layout

-config.define_bool_state(
+mosaic_allow_hlo = config.config.define_bool_state(
     name="jax_mosaic_allow_hlo",
     default=False,
     help="Allow hlo dialects in Mosaic",
 )

-config.define_bool_state(
+mosaic_dump_mlir = config.config.define_bool_state(
     name="jax_mosaic_dump_mlir",
     default=False,
     help="Print mlir module after each pass",
@@ -243,7 +243,7 @@ def _lower_tpu_kernel(
   )
   dump_mlir(module, "initial module")

-  if config.jax_mosaic_allow_hlo:
+  if mosaic_allow_hlo.value:
     # Run hlo dialect conversion: hlo -> linalg -> vector.
     pipeline = [
         "hlo-legalize-to-arithmetic",
@@ -255,7 +255,7 @@ def _lower_tpu_kernel(
   )
   dump_mlir(module, "after hlo conversion module")

-  if config.mosaic_use_cpp_passes:
+  if mosaic_use_cpp_passes.value:
     pipeline = [
         (
             f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
@@ -278,7 +278,7 @@ def _lower_tpu_kernel(
   module.operation.verify()
   dump_mlir(module, "after infer vector layout pass")

-  if config.mosaic_use_cpp_passes:
+  if mosaic_use_cpp_passes.value:
     pipeline = [
         (
             "func.func(tpu-apply-vector-layout{sublane-count=8"
@@ -418,6 +418,6 @@ def _lowered_as_tpu_kernel(

 def dump_mlir(module: ir.Module, msg: str):
   """A helper function to print mlir module with a message."""
-  if config.jax_mosaic_dump_mlir:
+  if mosaic_dump_mlir.value:
     print(f"[jax_mosaic_dump_mlir] {msg}")
     print(module)
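[Note: the Mosaic hunks show the other half of the migration: `define_bool_state` already returned a state holder, and the new code simply keeps that holder in a module-level variable instead of re-finding the flag by attribute lookup. A sketch of the same pattern with a hypothetical flag (the name `jax_my_debug_dumps` is made up for illustration):

from jax._src import config

# Hypothetical flag, registered the same way as the Mosaic flags above.
my_debug_dumps = config.config.define_bool_state(
    name="jax_my_debug_dumps",
    default=False,
    help="Illustrative flag: dump extra debug artifacts.",
)

def maybe_dump(artifact):
  if my_debug_dumps.value:  # read the holder directly; no string lookup
    print(artifact)
]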
@@ -19,9 +19,9 @@ import traceback
 import types
 from typing import Any, Callable, Optional, TypeVar, cast

-from jax._src.config import config
+from jax._src import config
-from jax._src.lib import xla_extension
 from jax._src import util
+from jax._src.lib import xla_extension


 C = TypeVar("C", bound=Callable[..., Any])
@@ -139,7 +139,7 @@ def _ipython_supports_tracebackhide() -> bool:
   return IPython.version_info[:2] >= (7, 17)

 def _filtering_mode() -> str:
-  mode = config.jax_traceback_filtering
+  mode = config.traceback_filtering.value
   if mode is None or mode == "auto":
     if (_running_under_ipython() and _ipython_supports_tracebackhide()):
       mode = "tracebackhide"
@ -23,9 +23,9 @@ import weakref
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from jax._src import config
|
||||||
from jax._src.lib import xla_client as xc
|
from jax._src.lib import xla_client as xc
|
||||||
from jax._src.lib import utils as jaxlib_utils
|
from jax._src.lib import utils as jaxlib_utils
|
||||||
from jax._src.config import config
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -257,10 +257,10 @@ def cache(max_size=4096):

     @functools.wraps(f)
     def wrapper(*args, **kwargs):
-      if config.jax_check_tracer_leaks:
+      if config.check_tracer_leaks.value:
         return f(*args, **kwargs)
       else:
-        return cached(config._trace_context(), *args, **kwargs)
+        return cached(config.config._trace_context(), *args, **kwargs)

     wrapper.cache_clear = cached.cache_clear
     wrapper.cache_info = cached.cache_info
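Note the asymmetry in the wrapper above: with `check_tracer_leaks` enabled the cache is bypassed entirely, otherwise the current trace context is folded into the cache key so entries cached under one config state are never served under another. A self-contained sketch of that keying, where the lambda stands in for `config.config._trace_context()`:

```python
import functools

def cache_with_context(get_context):
  """Decorator: memoize f, but include get_context() in every cache key."""
  def decorator(f):
    @functools.lru_cache(maxsize=4096)
    def cached(context, *args):
      del context  # participates in the key only
      return f(*args)

    @functools.wraps(f)
    def wrapper(*args):
      return cached(get_context(), *args)

    wrapper.cache_clear = cached.cache_clear
    return wrapper
  return decorator

@cache_with_context(lambda: ("x64", False))  # stand-in for the trace context
def square(x):
  return x * x

assert square(3) == 9
```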
@@ -278,7 +278,7 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
   behave similar to `functools.lru_cache`.
   """
   global _weakref_lru_caches
-  cached_call = xc.weakref_lru_cache(config._trace_context, call, maxsize)
+  cached_call = xc.weakref_lru_cache(config.config._trace_context, call, maxsize)
   _weakref_lru_caches.add(cached_call)
   return cached_call

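As the hunk shows, `xc.weakref_lru_cache` takes a context callable, the function to cache, and a maximum size; it holds the first positional argument of each cached call weakly, so dropping the last strong reference lets that object's entries be evicted. A hedged usage sketch, assuming a jaxlib that provides this binding:

```python
from jax._src.lib import xla_client as xc

class Key:  # a weakly-referenceable object to cache on
  pass

def describe(key, flag):
  return (id(key), flag)

cached_describe = xc.weakref_lru_cache(lambda: (), describe, 128)

k = Key()
assert cached_describe(k, True) == cached_describe(k, True)  # second call is cached
del k  # entries keyed on k may now be collected
```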
@@ -464,7 +464,7 @@ def distributed_debug_log(*pairs):
     pairs: A sequence of label/value pairs to log. The first pair is treated as
       a heading for subsequent pairs.
   """
-  if config.jax_distributed_debug:
+  if config.distributed_debug.value:
     lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
     try:
       lines.append(f"{pairs[0][0]}: {pairs[0][1]}")
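`distributed_debug_log` follows the same gate-first pattern: no formatting work is done unless the debug flag is set. A minimal sketch of the shape, with the flag passed in explicitly instead of read from config:

```python
def debug_log(enabled: bool, *pairs) -> None:
  """Sketch of the gate in distributed_debug_log: do no work unless enabled."""
  if not enabled:
    return
  lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
  lines.append(f"{pairs[0][0]}: {pairs[0][1]}")  # first pair is the heading
  for label, value in pairs[1:]:
    lines.append(f"  {label}: {value}")
  lines.append("DISTRIBUTED_DEBUG_END")
  print("\n".join(lines))

debug_log(True, ("make_gpu_client", "starting"), ("visible_devices", "all"))
```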
@@ -35,9 +35,8 @@ import threading
 from typing import Any, Callable, Optional, Union
 import warnings

+from jax._src import config
 from jax._src import distributed
-from jax._src import config as jax_config
-from jax._src.config import config
 from jax._src.lib import cuda_versions
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
@@ -63,34 +62,34 @@ XlaBackend = xla_client.Client


 # TODO(phawkins): Remove jax_xla_backend.
-_XLA_BACKEND = jax_config.DEFINE_string(
+_XLA_BACKEND = config.DEFINE_string(
     'jax_xla_backend', '',
     'Deprecated, please use --jax_platforms instead.')
-BACKEND_TARGET = jax_config.DEFINE_string(
+BACKEND_TARGET = config.DEFINE_string(
     'jax_backend_target',
     os.getenv('JAX_BACKEND_TARGET', '').lower(),
     'Either "local" or "rpc:address" to connect to a remote service target.')
 # TODO(skye): warn when this is used once we test out --jax_platforms a bit
-_PLATFORM_NAME = jax_config.DEFINE_string(
+_PLATFORM_NAME = config.DEFINE_string(
     'jax_platform_name',
     os.getenv('JAX_PLATFORM_NAME', '').lower(),
     'Deprecated, please use --jax_platforms instead.')
-CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
+CUDA_VISIBLE_DEVICES = config.DEFINE_string(
     'jax_cuda_visible_devices', 'all',
     'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
     'comma-separate list of integer device IDs.')
-_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
+_ROCM_VISIBLE_DEVICES = config.DEFINE_string(
     'jax_rocm_visible_devices', 'all',
     'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
     'comma-separate list of integer device IDs.')

-_USE_MOCK_GPU_CLIENT = jax_config.DEFINE_bool(
+_USE_MOCK_GPU_CLIENT = config.DEFINE_bool(
     name="use_mock_gpu_client",
     default=False,
     help="If True, use a mock GPU client instead of a real one.",
 )

-_MOCK_NUM_GPUS = jax_config.DEFINE_integer(
+_MOCK_NUM_GPUS = config.DEFINE_integer(
     name="mock_num_gpus",
     default=1,
     help="Mock GPU client number of gpus.",
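These definitions only change which module the `DEFINE_*` helpers are imported from; the shape of the API stays the same: each `DEFINE_string`/`DEFINE_bool`/`DEFINE_integer` call registers a flag and returns a typed holder. A hypothetical sketch of that shape — not the real `jax._src.config` code:

```python
from typing import Generic, TypeVar

T = TypeVar("T")

class FlagHolder(Generic[T]):
  """Typed handle returned by a DEFINE_* helper; read the flag via `.value`."""
  def __init__(self, name: str, default: T):
    self.name = name
    self._value = default

  @property
  def value(self) -> T:
    return self._value

def DEFINE_string(name: str, default: str, help: str) -> FlagHolder[str]:
  del help  # real registration (absl flags / env vars) is elided in this sketch
  return FlagHolder(name, default)

CUDA_VISIBLE_DEVICES = DEFINE_string(
    'jax_cuda_visible_devices', 'all',
    'Restricts the set of CUDA devices that JAX will use.')
assert CUDA_VISIBLE_DEVICES.value == 'all'
```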
@@ -223,7 +222,7 @@ def _check_cuda_versions():


 def make_gpu_client(
-    *, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str]
+    *, platform_name: str, visible_devices_flag: config.FlagHolder[str]
 ) -> xla_client.Client:
   visible_devices = visible_devices_flag.value
   allowed_devices = None
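`make_gpu_client` takes the holder itself and reads `.value` at call time rather than import time. The visible-devices handling it feeds looks roughly like this — a self-contained sketch with the holder replaced by a plain string:

```python
from typing import Optional

def parse_visible_devices(visible_devices: str) -> Optional[set[int]]:
  """Returns None for "all" (no restriction), else the allowed device ids."""
  if visible_devices == "all":
    return None
  return {int(x) for x in visible_devices.split(",")}

assert parse_visible_devices("all") is None
assert parse_visible_devices("0,2") == {0, 2}
```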
@@ -564,11 +563,10 @@ def backends() -> dict[str, xla_client.Client]:
   with _backend_lock:
     if _backends:
       return _backends
-    if config.jax_platforms:
-      jax_platforms = config.jax_platforms.split(",")
+    if jax_platforms := config.jax_platforms.value:
       platforms = []
       # Allow platform aliases in the list of platforms.
-      for platform in jax_platforms:
+      for platform in jax_platforms.split(","):
         platforms.extend(expand_platform_alias(platform))
       priorities = range(len(platforms), 0, -1)
       # If the user specified a list of platforms explicitly, always fail
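The one behavioral subtlety in this hunk is the walrus operator: `config.jax_platforms.value` is read once, its truthiness gates the branch, and the `split(",")` moves into the loop header. The refactor in isolation, with a stand-in for `expand_platform_alias`:

```python
def expand_alias(platform: str) -> list[str]:  # stand-in for expand_platform_alias
  return ["cuda", "rocm"] if platform == "gpu" else [platform]

def resolve_platforms(raw: str) -> list[str]:
  platforms: list[str] = []
  if jax_platforms := raw:  # bind and truth-test in one step
    for platform in jax_platforms.split(","):  # split moved into the loop header
      platforms.extend(expand_alias(platform))
  return platforms

assert resolve_platforms("cpu,gpu") == ["cpu", "cuda", "rocm"]
assert resolve_platforms("") == []
```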
@@ -597,14 +595,14 @@ def backends() -> dict[str, xla_client.Client]:
         _backend_errors[platform] = str(err)
         logger.info(err_msg)
       else:
-        if config.jax_platforms:
+        if config.jax_platforms.value:
           err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
         else:
           err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
         raise RuntimeError(err_msg)

   assert _default_backend is not None
-  if not config.jax_platforms:
+  if not config.jax_platforms.value:
     _suggest_missing_backends()
   return _backends

@@ -55,6 +55,7 @@ import numpy as np

 import jax
 from jax import lax
+from jax._src import config
 from jax._src import core
 from jax._src.custom_derivatives import lift_jvp
 from jax._src import linear_util as lu
@@ -67,7 +68,6 @@ from jax._src.lib import pytree
 from jax._src.interpreters import partial_eval as pe
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 from jax.util import safe_map, safe_zip, split_list
-from jax._src.config import config
 from jax._src.lax.control_flow import _check_tree_and_avals
 from jax._src.numpy import lax_numpy
 from jax.experimental import sparse
@@ -614,7 +614,7 @@ def _add_sparse(spenv, *spvalues):
     raise NotImplementedError("Addition between sparse matrices of different shapes.")
   if X.indices_ref == Y.indices_ref:
     out_data = lax.add(spenv.data(X), spenv.data(Y))
-    if config.jax_enable_checks:
+    if config.enable_checks.value:
       assert X.indices_sorted == Y.indices_sorted
       assert X.unique_indices == Y.unique_indices
     out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref,
@@ -657,7 +657,7 @@ def _mul_sparse(spenv, *spvalues):
   X, Y = spvalues
   if X.is_sparse() and Y.is_sparse():
     if X.indices_ref == Y.indices_ref and X.unique_indices:
-      if config.jax_enable_checks:
+      if config.enable_checks.value:
         assert X.indices_sorted == Y.indices_sorted
         assert X.unique_indices == Y.unique_indices
       out_data = lax.mul(spenv.data(X), spenv.data(Y))
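Both sparse hunks use the same pattern: debug assertions on the shared-indices metadata run only when the (normally off) checks flag is set. A minimal self-contained sketch of that gating, with a stand-in for the sparse transform's values:

```python
from dataclasses import dataclass

@dataclass
class SpValue:  # minimal stand-in for the sparse transform's values
  indices_ref: int
  indices_sorted: bool
  unique_indices: bool

def check_matching_metadata(X: SpValue, Y: SpValue, enable_checks: bool) -> None:
  if enable_checks:  # mirrors `if config.enable_checks.value:`
    assert X.indices_sorted == Y.indices_sorted
    assert X.unique_indices == Y.unique_indices

check_matching_metadata(SpValue(0, True, True), SpValue(0, True, True),
                        enable_checks=True)
```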