diff --git a/jax/_src/api.py b/jax/_src/api.py index 6f631367a..62347fe1f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -28,9 +28,8 @@ from functools import partial import inspect import math import typing -from typing import (Any, Callable, Literal, - NamedTuple, Optional, TypeVar, Union, - overload, cast) +from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload, + cast) import weakref import numpy as np @@ -42,6 +41,7 @@ from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, prefix_errors, generate_key_paths) +from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import effects @@ -53,7 +53,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import pjit from jax._src import xla_bridge as xb -from jax._src.core import eval_jaxpr +from jax._src.core import eval_jaxpr, ShapedArray from jax._src.api_util import ( flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, argnums_partial_except, flatten_axes, donation_vector, @@ -73,23 +73,12 @@ from jax._src import tree_util from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps from jax._src import util - -from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import mlir -from jax._src.interpreters import xla - -from jax._src.config import ( - config, - disable_jit as _disable_jit, - debug_nans as config_debug_nans, - debug_infs as config_debug_infs, - _thread_local_state as config_thread_local_state, - explicit_device_put_scope as config_explicit_device_put_scope, - explicit_device_get_scope as config_explicit_device_get_scope) -from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla +from jax._src.interpreters import xla traceback_util.register_exclusion(__file__) @@ -121,27 +110,28 @@ def _nan_check_posthook(fun, args, kwargs, output): dispatch.check_special(pjit.pjit_p.name, buffers) except FloatingPointError: # compiled_fun can only raise in this case - assert config.jax_debug_nans or config.jax_debug_infs + assert config.debug_nans.value or config.debug_infs.value print("Invalid nan value encountered in the output of a C++-jit/pmap " "function. Calling the de-optimized version.") fun._cache_miss(*args, **kwargs)[0] # probably won't return def _update_debug_special_global(_): - if config._read("jax_debug_nans") or config._read("jax_debug_infs"): + if (config.config._read("jax_debug_nans") or + config.config._read("jax_debug_infs")): jax_jit.global_state().post_hook = _nan_check_posthook else: jax_jit.global_state().post_hook = None def _update_debug_special_thread_local(_): - if (getattr(config_thread_local_state, "jax_debug_nans", False) or - getattr(config_thread_local_state, "jax_debug_infs", False)): + if (getattr(config._thread_local_state, "jax_debug_nans", False) or + getattr(config._thread_local_state, "jax_debug_infs", False)): jax_jit.thread_local_state().post_hook = _nan_check_posthook else: jax_jit.thread_local_state().post_hook = None -config_debug_nans._add_hooks(_update_debug_special_global, +config.debug_nans._add_hooks(_update_debug_special_global, _update_debug_special_thread_local) -config_debug_infs._add_hooks(_update_debug_special_global, +config.debug_infs._add_hooks(_update_debug_special_global, _update_debug_special_thread_local) @@ -376,7 +366,7 @@ def disable_jit(disable: bool = True): Value of y is [2 4 6] [5 7 9] """ - with _disable_jit(disable): + with config.disable_jit(disable): yield @@ -1579,7 +1569,7 @@ def pmap( # TODO(yashkatariya): Move this out after shard_map is out of experimental and # in _src - if config.jax_pmap_shmap_merge: + if config.pmap_shmap_merge.value: from jax.experimental.shard_map import pmap return pmap(fun, axis_name, in_axes=in_axes, out_axes=out_axes, static_broadcasted_argnums=static_broadcasted_argnums, @@ -1670,7 +1660,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, dyn_args, dyn_in_axes = args, in_axes args, in_tree = tree_flatten((dyn_args, kwargs)) - if donate_tuple and not config.jax_debug_nans: + if donate_tuple and not config.debug_nans.value: donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs) else: donated_invars = (False,) * len(args) @@ -2530,7 +2520,7 @@ def device_put( This function is always asynchronous, i.e. returns immediately without blocking the calling Python thread until any transfers are completed. """ - with config_explicit_device_put_scope(): + with config.explicit_device_put_scope(): if ((device is None or isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and (src is None or @@ -2624,7 +2614,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # return pxla.batched_device_put(stacked_aval, sharding, xs, list(devices)) - with config_explicit_device_put_scope(): + with config.explicit_device_put_scope(): return tree_map(_device_put_sharded, *shards) @@ -2674,7 +2664,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 assert len(xla.aval_to_xla_shapes(aval)) == 1 return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) - with config_explicit_device_put_scope(): + with config.explicit_device_put_scope(): return tree_map(_device_put_replicated, x) @@ -2720,7 +2710,7 @@ def device_get(x: Any): - device_put_sharded - device_put_replicated """ - with config_explicit_device_get_scope(): + with config.explicit_device_get_scope(): for y in tree_leaves(x): try: y.copy_to_host_async() diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ed896dd5c..b692eff37 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -20,9 +20,9 @@ from functools import partial from typing import Any, Callable, Optional, Union import jax +from jax._src import config from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax import config from jax.tree_util import (tree_flatten, tree_unflatten, register_pytree_node, Partial) from jax._src import core @@ -200,7 +200,7 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, ct = jax.lax.psum(ct, axis_name=axes_to_reduce) ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct # TODO(mattjj): add back these checks for dynamic shapes - # if config.jax_enable_checks: + # if config.enable_checks.value: # ct_aval = core.get_aval(ct_env[v]) # joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape() # assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval) @@ -458,7 +458,7 @@ class JVPTracer(Tracer): __slots__ = ['primal', 'tangent'] def __init__(self, trace, primal, tangent): - if config.jax_enable_checks: + if config.enable_checks.value: _primal_tangent_shapes_match(primal, tangent) self._trace = trace self.primal = primal @@ -624,7 +624,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): if update_params: params = update_params(params, map(is_undefined_primal, args), [type(x) is not Zero for x in ct]) - if config.jax_dynamic_shapes: + if config.dynamic_shapes.value: # TODO(mattjj,dougalm): handle consts, for now assume just args which_lin = [is_undefined_primal(x) for x in args] res_invars, _ = partition_list(which_lin, call_jaxpr.invars) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 843f29fe3..37d449be8 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -22,7 +22,7 @@ from typing import Any, Callable, Union import numpy as np import jax -from jax import config +from jax._src import config from jax._src import core from jax._src import source_info_util from jax._src import linear_util as lu @@ -318,7 +318,7 @@ class BatchTracer(Tracer): def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, source_info: source_info_util.SourceInfo | None = None): - if config.jax_enable_checks: + if config.enable_checks.value: assert type(batch_dim) in (NotMapped, int, RaggedAxis) if type(batch_dim) is int: aval = raise_to_shaped(core.get_aval(val)) @@ -416,7 +416,7 @@ class BatchTrace(Trace): return frame def process_primitive(self, primitive, tracers, params): - if config.jax_dynamic_shapes: + if config.dynamic_shapes.value: primitive.abstract_eval(*(t.aval for t in tracers), **params) vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers) is_axis_primitive = primitive in axis_primitive_batchers diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index b54dcd4eb..c1a8b4de9 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -22,9 +22,9 @@ import itertools import operator from typing import Callable -from jax import config from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util +from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -129,7 +129,7 @@ def switch(index, branches: Sequence[Callable], *operands, hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) - if (config.jax_disable_jit and + if (config.disable_jit.value and isinstance(core.get_aval(index), ConcreteArray)): return branches[int(index)](*operands) @@ -226,7 +226,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, msg = ("Pred type must be either boolean or number, got {}.") raise TypeError(msg.format(pred_dtype)) - if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray): + if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray): if pred: return true_fun(*operands) else: @@ -806,7 +806,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches, linear): return jaxpr0.out_avals, joined_effects def cond_bind(*args, branches, linear): - if config.jax_enable_checks: + if config.enable_checks.value: avals = map(core.get_aval, args) in_atoms = [core.Var(0, '', a) for a in avals] # dummies _cond_typecheck(True, *in_atoms, branches=branches, linear=linear) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 4a6e12d6e..160bae300 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -22,9 +22,9 @@ from typing import Any, Callable, Optional, TypeVar import jax import weakref +from jax._src import config from jax._src import core from jax._src import linear_util as lu -from jax import config # type: ignore[no-redef] from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, tree_map, tree_flatten_with_path, keystr) @@ -214,7 +214,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], else: length, = unique_lengths - if config.jax_disable_jit: + if config.disable_jit.value: if length == 0: raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.") carry = init @@ -859,7 +859,7 @@ def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn used_carry_out = _map(operator.or_, used_carry_out, used_carry_in) else: assert False, "Fixpoint not reached" - if config.jax_enable_checks: core.check_jaxpr(jaxpr.jaxpr) + if config.enable_checks.value: core.check_jaxpr(jaxpr.jaxpr) new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u] new_params = dict(eqn.params, num_consts=sum(used_consts), @@ -1105,7 +1105,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, return new_invals, [*carry_out, *ys_out] def scan_bind(*args, **params): - if config.jax_enable_checks: + if config.enable_checks.value: avals = _map(core.get_aval, args) in_atoms = [core.Var(0, '', a) for a in avals] # dummies _scan_typecheck(True, *in_atoms, **params) @@ -1190,7 +1190,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric], """ if not (callable(body_fun) and callable(cond_fun)): raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.") - if config.jax_disable_jit: + if config.disable_jit.value: try: val = init_val while cond_fun(val): @@ -1931,7 +1931,7 @@ def fori_loop(lower, upper, body_fun, init_val): use_scan = False if use_scan: - if config.jax_disable_jit and upper_ == lower_: + if config.disable_jit.value and upper_ == lower_: # non-jit implementation of scan does not support length=0 return init_val diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index e9d6caa9c..2e3ac7005 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -25,6 +25,7 @@ import numpy as np import jax from jax._src import ad_util +from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -156,7 +157,7 @@ def dynamic_slice( - :func:`jax.lax.dynamic_index_in_dim` """ start_indices = _dynamic_slice_indices(operand, start_indices) - if jax.config.jax_dynamic_shapes: + if config.dynamic_shapes.value: dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes) else: dynamic_sizes = [] @@ -1090,7 +1091,7 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): msg = ("slice start_indices must be greater than or equal to zero, " "got start_indices of {}.") raise TypeError(msg.format(start_indices)) - if not jax.config.jax_dynamic_shapes: + if not config.dynamic_shapes.value: if not all(map(operator.ge, limit_indices, start_indices)): msg = ("slice limit_indices must be greater than or equal to start_indices," " got start_indices {} and limit_indices {}.") diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 62b87412e..eab9ab393 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -24,6 +24,7 @@ import jax import jax.numpy as jnp from jax import custom_jvp from jax import lax +from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import util @@ -446,7 +447,7 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ - if jax.config.jax_softmax_custom_jvp: + if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition # of `_softmax` and incorrectly concludes that `_softmax` returns # `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`. diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d45f7a9cc..7db0d87c5 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -47,6 +47,7 @@ from jax import lax from jax.tree_util import tree_leaves, tree_flatten, tree_map from jax._src import api_util +from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -56,7 +57,6 @@ from jax._src.core import ShapedArray, ConcreteArray from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, _sort_le_comparator, PrecisionLike) from jax._src.lax import lax as lax_internal -from jax._src.lib import xla_client from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util @@ -2321,7 +2321,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: def arange(start: DimSize, stop: Optional[DimSize] = None, step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "arange") - if not jax.config.jax_dynamic_shapes: + if not config.dynamic_shapes.value: util.check_arraylike("arange", start) if stop is None and step is None: start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'") @@ -3459,7 +3459,7 @@ def _einsum( # NOTE(mattjj): this can fail non-deterministically in python3, maybe # due to opt_einsum - assert jax.config.jax_dynamic_shapes or all( + assert config.dynamic_shapes.value or all( name in lhs_names and name in rhs_names and lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)] for name in contracted_names), ( @@ -4309,7 +4309,7 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, return result # TODO(mattjj,dougalm): expand dynamic shape indexing support - if jax.config.jax_dynamic_shapes and arr.ndim > 0: + if config.dynamic_shapes.value and arr.ndim > 0: try: aval = core.get_aval(idx) except: pass else: diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 3b411ab10..c470d1da2 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -37,6 +37,7 @@ from typing import Any, NamedTuple, Protocol, Union import jax from jax._src import core +from jax._src import config from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util @@ -488,7 +489,7 @@ class Compiled(Stage): # which might conflict here. params = args[0] args = args[1:] - if jax.config.jax_dynamic_shapes: + if config.dynamic_shapes.value: raise NotImplementedError if params.no_kwargs and kwargs: kws = ', '.join(kwargs.keys()) diff --git a/jax/experimental/jax2tf/examples/saved_model_main_test.py b/jax/experimental/jax2tf/examples/saved_model_main_test.py index f5c30c566..3aafc834e 100644 --- a/jax/experimental/jax2tf/examples/saved_model_main_test.py +++ b/jax/experimental/jax2tf/examples/saved_model_main_test.py @@ -17,13 +17,14 @@ import os from absl import flags from absl.testing import absltest from absl.testing import parameterized +import jax +from jax._src import config from jax._src import test_util as jtu -from jax import config from jax.experimental.jax2tf.examples import saved_model_main from jax.experimental.jax2tf.tests import tf_test_util -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS @@ -48,8 +49,8 @@ class SavedModelMainTest(tf_test_util.JaxToTfTestCase): model="mnist_flax", serving_batch_size=-1): if (serving_batch_size == -1 and - config.jax2tf_default_native_serialization and - not config.jax_dynamic_shapes): + config.jax2tf_default_native_serialization.value and + not config.dynamic_shapes.value): self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") FLAGS.model = model FLAGS.model_classifier_layer = True diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 8a7d65004..9187c8db9 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -29,7 +29,6 @@ import numpy as np import jax from jax import lax -from jax import config from jax import custom_derivatives from jax import random from jax import numpy as jnp @@ -45,6 +44,7 @@ from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api from jax._src import api_util +from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -322,7 +322,7 @@ def convert(fun_jax: Callable, if not enable_xla: native_serialization = False else: - native_serialization = config.jax2tf_default_native_serialization + native_serialization = config.jax2tf_default_native_serialization.value if native_serialization and not enable_xla: raise ValueError( @@ -1180,7 +1180,7 @@ class TensorFlowTracer(core.Tracer): if isinstance(val, (tf.Tensor, tf.Variable)): val_shape = val.shape - if config.jax_enable_checks: + if config.enable_checks.value: assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than {val_shape=}" # To compare types, we must handle float0 in JAX and x64 in TF if phys_aval.dtype == dtypes.float0: @@ -1335,7 +1335,7 @@ class TensorFlowTrace(core.Trace): # Check that the impl rule returned a value of expected shape and dtype # TODO: adapt this to match polymorphic shapes - if config.jax_enable_checks: + if config.enable_checks.value: if primitive.multiple_results: for o, expected_aval in zip(out, out_aval): # type: ignore assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), ( @@ -2538,7 +2538,7 @@ tf_impl_with_avals[lax.reduce_p] = _reduce def _cumred(lax_reduce_fn: Callable, lax_reduce_window_fn: Callable, extra_name_stack: str): - if config.jax2tf_associative_scan_reductions: + if config.jax2tf_associative_scan_reductions.value: return _convert_jax_impl(partial(lax_control_flow.associative_scan, lax_reduce_fn), multiple_results=False, diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index de15803b0..48ef2d86c 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -48,12 +48,12 @@ from absl import testing import numpy as np import jax -from jax import config from jax import dtypes from jax import lax from jax import numpy as jnp from jax._src import ad_util +from jax._src import config from jax._src import dispatch from jax._src import prng from jax._src import test_util as jtu @@ -67,7 +67,7 @@ from jax._src import random as jax_random # then the test file has to import jtu first (to define the flags) which is not # desired if the test file is outside of this project (we don't want a # dependency on jtu outside of jax repo). -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() FLAGS = config.FLAGS Rng = Any # A random number generator @@ -2676,7 +2676,7 @@ for dtype in (np.float32, np.float64): def wrap_and_split(): key = jax.random.key(42) - if jax.config.jax_enable_custom_prng: + if config.enable_custom_prng.value: key = prng.random_wrap(key, impl=jax.random.default_prng_impl()) result = jax.random.split(key, 2) return prng.random_unwrap(result) @@ -3314,7 +3314,7 @@ for padding, lhs_dilation, rhs_dilation in [ rhs_dilation=rhs_dilation) key_types = [((4,), np.uint32)] -if config.jax_enable_x64: +if config.enable_x64.value: key_types.append(((2,), np.uint64)) for algorithm in [lax.RandomAlgorithm.RNG_THREE_FRY, diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 5880b3175..981fdb817 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -63,8 +63,8 @@ from absl.testing import parameterized import jax from jax import dtypes from jax import numpy as jnp +from jax._src import config from jax._src import test_util as jtu -from jax import config from jax.experimental import jax2tf from jax.interpreters import mlir from jax._src.interpreters import xla @@ -72,7 +72,7 @@ from jax._src.interpreters import xla import numpy as np import tensorflow as tf # type: ignore[import] -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Import after parsing flags from jax.experimental.jax2tf.tests import tf_test_util @@ -114,7 +114,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): func_jax = harness.dyn_fun args = harness.dyn_args_maker(self.rng()) enable_xla = harness.params.get("enable_xla", True) - if config.jax2tf_default_native_serialization and not enable_xla: + if config.jax2tf_default_native_serialization.value and not enable_xla: raise unittest.SkipTest("native_serialization not supported with enable_xla=False") if ("eigh" == harness.group_name and @@ -122,7 +122,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): device == "tpu"): raise unittest.SkipTest("b/264716764: error on tf.cast from c64 to f32") - if (config.jax2tf_default_native_serialization and + if (config.jax2tf_default_native_serialization.value and device == "gpu" and "lu" in harness.fullname): raise unittest.SkipTest("b/269388847: lu failures on GPU") @@ -130,7 +130,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): def skipCustomCallTest(target: str): raise unittest.SkipTest( f"TODO(b/272239584): custom call target not guaranteed stable: {target}") - if config.jax2tf_default_native_serialization: + if config.jax2tf_default_native_serialization.value: if device == "gpu": if "custom_linear_solve_" in harness.fullname: skipCustomCallTest("cusolver_geqrf, cublas_geqrf_batched") @@ -146,7 +146,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): enable_xla=enable_xla) except Exception as e: # TODO(b/264596006): custom calls are not registered properly with TF in OSS - if (config.jax2tf_default_native_serialization and + if (config.jax2tf_default_native_serialization.value and "does not work with custom calls" in str(e)): logging.warning("Suppressing error %s", e) raise unittest.SkipTest("b/264596006: custom calls in native serialization fail in TF") @@ -257,7 +257,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): # The CPU has more supported types, and harnesses self.assertEqual("cpu", jtu.device_under_test()) self.assertTrue( - config.x64_enabled, + config.enable_x64.value, "Documentation generation must be run with JAX_ENABLE_X64=1") with open( @@ -299,7 +299,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): y = np.int32(3) self.ConvertAndCompare(jnp.floor_divide, x, y) expected = jnp.floor_divide(x, y) - if not config.jax2tf_default_native_serialization: + if not config.jax2tf_default_native_serialization.value: # With native serialization TF1 seems to want to run the converted code # on the CPU even when the default backend is the TPU. # Try it with TF 1 as well (#5831) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 4efeb220c..09a6dcf7d 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -30,8 +30,8 @@ import unittest from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu -from jax import config from jax import lax from jax.experimental import jax2tf from jax.experimental import pjit @@ -47,7 +47,7 @@ import numpy as np import tensorflow as tf # type: ignore[import] -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Must come after initializing the flags from jax.experimental.jax2tf.tests import tf_test_util @@ -225,7 +225,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): # Annotation count for the input count_in_P = 1 if in_shardings == "P" else 0 - if config.jax2tf_default_native_serialization: + if config.jax2tf_default_native_serialization.value: # With native serialization even unspecified in_shardings turn into replicated count_in_replicated = 1 if in_shardings in [None, "missing"] else 0 else: @@ -400,14 +400,14 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): # Annotation count for the primal input and the grad output count_in_P = self.GEQ(2) if in_shardings == "P" else 0 - if config.jax2tf_default_native_serialization: + if config.jax2tf_default_native_serialization.value: # With native serialization even unspecified shardings turn into replicated count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0 else: count_in_replicated = self.GEQ(2) if in_shardings is None else 0 # Annotation count for the contangent input count_out_P = self.GEQ(1) if out_shardings == "P" else 0 - if config.jax2tf_default_native_serialization: + if config.jax2tf_default_native_serialization.value: # With native serialization even unspecified shardings turn into replicated count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0 else: @@ -479,7 +479,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): "nested_pjit_sharded", "nested_pjit_replicated") ]) def test_pjit_eager_error(self, func="pjit_sharded"): - if config.jax2tf_default_native_serialization: + if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("There is no error in eager mode for native serialization") # Define some test functions diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 82c40fca7..a54c142c2 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -29,9 +29,9 @@ from jax import numpy as jnp from jax._src import test_util as jtu from jax import tree_util -from jax import config from jax.experimental import jax2tf from jax.experimental.export import export +from jax._src import config from jax._src import xla_bridge import numpy as np import tensorflow as tf # type: ignore[import] @@ -177,16 +177,14 @@ class JaxToTfTestCase(jtu.JaxTestCase): # We run the tests using the maximum version supported, even though # the default serialization version may be held back for a while to # ensure compatibility - version = config.jax_serialization_version - self.addCleanup(functools.partial(config.update, - "jax_serialization_version", version)) + version = config.jax_serialization_version.value if self.use_max_serialization_version: # Use the largest supported by both export and tfxla.call_module version = min(export.maximum_supported_serialization_version, tfxla.call_module_maximum_supported_version()) self.assertGreaterEqual(version, export.minimum_supported_serialization_version) - config.update("jax_serialization_version", version) + self.enter_context(config.jax_serialization_version(version)) logging.info( "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)", version, @@ -203,7 +201,7 @@ class JaxToTfTestCase(jtu.JaxTestCase): def to_numpy_dtype(dt): return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype - if not config.x64_enabled and canonicalize_dtypes: + if not config.enable_x64.value and canonicalize_dtypes: self.assertEqual( dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))), dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y)))) @@ -410,7 +408,7 @@ class JaxToTfTestCase(jtu.JaxTestCase): # graph. We count the number of characters in the textual representation # of the constant. f_tf_graph = tf.function(tf_fun, autograph=False).get_concrete_function(*args).graph.as_graph_def() - if config.jax2tf_default_native_serialization: + if config.jax2tf_default_native_serialization.value: # This way of finding constants may be brittle, if the constant representation # contains >. It seems tobe hex-encoded, so this may be safe. large_consts = [m for m in re.findall(r"dense<([^>]+)>", str(f_tf_graph)) if len(m) >= at_least] diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 2099ffb72..3082b2a13 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -26,12 +26,12 @@ import numpy as np import jax from jax import numpy as jnp +from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu from jax._src.util import NumpyComplexWarning -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() FLAGS = config.FLAGS numpy_version = jtu.numpy_version() @@ -705,7 +705,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) - @unittest.skipIf(not config.jax_enable_x64, "test requires X64") + @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): # Regression test for https://github.com/google/jax/issues/8513 diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 0d0c951b0..9017bdd56 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -27,11 +27,11 @@ from jax import lax from jax.tree_util import register_pytree_node_class import jax._src.scipy.sparse.linalg as sp_linalg +from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_types = jtu.dtypes.floating @@ -100,7 +100,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): preconditioner=[None, 'identity', 'exact', 'random'], ) def test_cg_against_scipy(self, shape, dtype, preconditioner): - if not config.x64_enabled: + if not config.enable_x64.value: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) @@ -221,7 +221,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): ) def test_bicgstab_against_scipy( self, shape, dtype, preconditioner): - if not config.jax_enable_x64: + if not config.enable_x64.value: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) @@ -326,7 +326,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): ) def test_gmres_against_scipy( self, shape, dtype, preconditioner, solve_method): - if not config.x64_enabled: + if not config.enable_x64.value: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) @@ -437,7 +437,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): """ The Arnoldi decomposition within GMRES is correct. """ - if not config.x64_enabled: + if not config.enable_x64.value: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) diff --git a/tests/nn_test.py b/tests/nn_test.py index 089473e61..8d140dd91 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import scipy.stats +from jax._src import config from jax._src import core from jax._src import test_util as jtu from jax._src import ad_checkpoint @@ -33,8 +34,7 @@ from jax import random import jax import jax.numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class NNFunctionsTest(jtu.JaxTestCase): @@ -143,7 +143,7 @@ class NNFunctionsTest(jtu.JaxTestCase): # TODO(mattjj): include log_softmax in these extra tests if/when we add a # custom_jvp rule for it (since otherwise it doesn't pass the numerical # checks below). - if fn is nn.softmax and config.jax_softmax_custom_jvp: + if fn is nn.softmax and config.softmax_custom_jvp.value: g_fun = lambda x: jnp.take(fn(x, where=m, initial=-jnp.inf), jnp.array([0, 2, 3])) jtu.check_grads(g_fun, (x,), order=2) @@ -153,7 +153,7 @@ class NNFunctionsTest(jtu.JaxTestCase): jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3) def testSoftmaxGradResiduals(self): - if not jax.config.jax_softmax_custom_jvp: + if not config.softmax_custom_jvp.value: raise unittest.SkipTest("only applies when upgrade flag enabled") x = jnp.array([5.5, 1.3, -4.2, 0.9]) res = ad_checkpoint.saved_residuals(nn.softmax, x) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e0810e97e..21636c4f0 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -30,6 +30,7 @@ import concurrent.futures import jax import jax.numpy as jnp from jax._src import core +from jax._src import config from jax._src import test_util as jtu from jax import dtypes from jax import stages @@ -59,8 +60,7 @@ from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2, safe_zip -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None @@ -955,7 +955,7 @@ class PJitTest(jtu.BufferDonationTestCase): @jtu.with_mesh([('x', 2)]) def testWithCustomPRNGKey(self): - if not config.jax_enable_custom_prng: + if not config.enable_custom_prng.value: raise unittest.SkipTest("test requires jax_enable_custom_prng") key = prng.seed_with_impl(prng.rbg_prng_impl, 87) # Make sure this doesn't crash diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 63e6dd7df..036b99cf9 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -22,8 +22,8 @@ import unittest from absl.testing import absltest from absl.testing import parameterized import jax -from jax import config from jax import lax +from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import test_util as jtu @@ -40,7 +40,7 @@ import jax.numpy as jnp from jax.sharding import Mesh import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _format_multiline(text): @@ -197,7 +197,7 @@ class PythonCallbackTest(jtu.JaxTestCase): @with_pure_and_io_callbacks def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback): - if config.jax_enable_x64: + if config.enable_x64.value: raise unittest.SkipTest("Test only needed when 64-bit mode disabled.") @jax.jit diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 4450092dc..795bab09a 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -29,6 +29,7 @@ from jax import grad from jax import lax from jax import numpy as jnp from jax import random +from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import test_util as jtu @@ -36,8 +37,7 @@ from jax import vmap from jax._src import prng as prng_internal -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.all_floating complex_dtypes = jtu.dtypes.complex @@ -63,7 +63,7 @@ class LaxRandomTest(jtu.JaxTestCase): fail_prob = 0.003 if samples.dtype == jnp.bfloat16 else 0.01 # TODO(frostig): This reads enable_custom_prng as a proxy for # whether RBG keys may be involved, but that's no longer exact. - if config.jax_enable_custom_prng and samples.dtype == jnp.bfloat16: + if config.enable_custom_prng.value and samples.dtype == jnp.bfloat16: return self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob) @@ -382,7 +382,7 @@ class LaxRandomTest(jtu.JaxTestCase): dtype=[np.float64], # NOTE: KS test fails with float32 ) def testBeta(self, a, b, dtype): - if not config.x64_enabled: + if not config.enable_x64.value: raise SkipTest("skip test except on X64") key = self.make_key(0) rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype) @@ -824,7 +824,7 @@ class LaxRandomTest(jtu.JaxTestCase): assert np.unique(keys, axis=0).shape[0] == 2 def testStaticShapeErrors(self): - if config.jax_disable_jit: + if config.disable_jit.value: raise SkipTest("test only relevant when jit enabled") @jax.jit @@ -991,7 +991,7 @@ class LaxRandomTest(jtu.JaxTestCase): def test_prng_jit_invariance(self, seed, type_): if type_ == "int" and seed == (1 << 64) - 1: self.skipTest("Expected failure: Python int too large.") - if not config.x64_enabled and seed > np.iinfo(np.int32).max: + if not config.enable_x64.value and seed > np.iinfo(np.int32).max: self.skipTest("Expected failure: Python int too large.") type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] args_maker = lambda: [type_(seed)] diff --git a/tests/random_test.py b/tests/random_test.py index 8d58bdd5b..d320a714a 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -30,6 +30,7 @@ from jax import lax from jax import numpy as jnp from jax import random from jax import tree_util +from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import test_util as jtu @@ -39,8 +40,7 @@ from jax.interpreters import xla from jax._src import random as jax_random from jax._src import prng as prng_internal -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() PRNG_IMPLS = list(prng_internal.prngs.items()) @@ -247,7 +247,7 @@ class PrngTest(jtu.JaxTestCase): finally: xla.apply_primitive = apply_primitive - @skipIf(config.jax_threefry_partitionable, 'changed random bit values') + @skipIf(config.threefry_partitionable.value, 'changed random bit values') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def testRngRandomBits(self, make_key): # Test specific outputs to ensure consistent random values between JAX versions. @@ -280,7 +280,7 @@ class PrngTest(jtu.JaxTestCase): with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"): bits64 = random_bits(key, 64, (3,)) - if config.x64_enabled: + if config.enable_x64.value: expected64 = np.array([3982329540505020460, 16822122385914693683, 7882654074788531506], dtype=np.uint64) else: @@ -315,11 +315,11 @@ class PrngTest(jtu.JaxTestCase): with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"): bits64 = random_bits(key, 64, (3,)) - expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32') + expected_dtype = np.dtype('uint64' if config.enable_x64.value else 'uint32') self.assertEqual(bits64.shape, (3,)) self.assertEqual(bits64.dtype, expected_dtype) - @skipIf(config.jax_threefry_partitionable, 'changed random bit values') + @skipIf(config.threefry_partitionable.value, 'changed random bit values') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def testRngRandomBitsViewProperty(self, make_key): # TODO: add 64-bit if it ever supports this property. @@ -338,7 +338,7 @@ class PrngTest(jtu.JaxTestCase): @jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS) - @skipIf(config.jax_threefry_partitionable, 'changed random bit values') + @skipIf(config.threefry_partitionable.value, 'changed random bit values') @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testRandomDistributionValues(self, case, make_key): """ @@ -355,9 +355,9 @@ class PrngTest(jtu.JaxTestCase): * Considering adding a flag that reverts the new behavior, made available for a deprecation window's amount of time. """ - if config.x64_enabled and case.on_x64 == OnX64.SKIP: + if config.enable_x64.value: self.skipTest("test produces different values when jax_enable_x64=True") - if not config.x64_enabled and case.on_x64 == OnX64.ONLY: + if not config.enable_x64.value: self.skipTest("test only valid when jax_enable_x64=True") with jax.default_prng_impl(case.prng_impl): func = getattr(random, case.name) @@ -368,7 +368,7 @@ class PrngTest(jtu.JaxTestCase): actual = func(key, **case.params, shape=case.shape) self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol) - @skipIf(config.jax_threefry_partitionable, 'changed random bit values') + @skipIf(config.threefry_partitionable.value, 'changed random bit values') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def testPRNGValues(self, make_key): # Test to ensure consistent random values between JAX versions @@ -376,7 +376,7 @@ class PrngTest(jtu.JaxTestCase): self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype, dtypes.canonicalize_dtype(jnp.int_)) - if config.x64_enabled: + if config.enable_x64.value: self.assertAllClose( random.randint(k, (3, 3), 0, 8, dtype='int64'), np.array([[7, 2, 6], @@ -407,7 +407,7 @@ class PrngTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, msg): random.bits(make_key(0), (3, 4), np.dtype('float16')) - @skipIf(not config.jax_threefry_partitionable, 'enable after upgrade') + @skipIf(not config.threefry_partitionable.value, 'enable after upgrade') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_threefry_split_fold_in_symmetry(self, make_key): with jax.default_prng_impl('threefry2x32'): @@ -420,7 +420,7 @@ class PrngTest(jtu.JaxTestCase): self.assertArraysEqual(f2, s2) self.assertArraysEqual(f3, s3) - @skipIf(not config.jax_threefry_partitionable, 'enable after upgrade') + @skipIf(not config.threefry_partitionable.value, 'enable after upgrade') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_threefry_split_vmapped_fold_in_symmetry(self, make_key): # See https://github.com/google/jax/issues/7708 @@ -435,7 +435,7 @@ class PrngTest(jtu.JaxTestCase): self.assertArraysEqual(f2, s2) self.assertArraysEqual(f3, s3) - @skipIf(config.jax_threefry_partitionable, 'changed random bit values') + @skipIf(config.threefry_partitionable.value, 'changed random bit values') def test_loggamma_nan_corner_case(self): # regression test for https://github.com/google/jax/issues/17922 # This particular key previously led to NaN output. @@ -459,20 +459,20 @@ class PrngTest(jtu.JaxTestCase): {"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]}, {"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]}, {"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]}, - {"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]}, - {"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]}, + {"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]}, + {"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]}, {"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]}, {"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]}, - {"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]}, - {"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]}, + {"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]}, + {"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]}, {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]}, {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]}, {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]}, {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]}, - {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]}, - {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]}, - {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]}, - {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]}, + {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]}, + {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]}, + {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]}, + {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]}, ] for params in [dict(**d, make_key=ctor) for ctor in KEY_CTORS] ]) @@ -482,7 +482,7 @@ class PrngTest(jtu.JaxTestCase): maker = lambda k: random.key_data(jax.jit(make_key)(k)) else: maker = lambda k: random.key_data(make_key(k)) - if (jit and typ is int and not config.x64_enabled and + if (jit and typ is int and not config.enable_x64.value and (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)): # We expect an error to be raised. # NOTE: we check 'if jit' because some people rely on builtin int seeds @@ -607,7 +607,7 @@ class KeyArrayTest(jtu.JaxTestCase): with self.assertRaisesRegex(TypeError, "Cannot interpret"): jnp.issubdtype(key, dtypes.prng_key) - @skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag') + @skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag') def test_construction_upgrade_flag(self): key = random.PRNGKey(42) self.assertIsInstance(key, jax_random.PRNGKeyArray) @@ -1059,7 +1059,7 @@ class KeyArrayTest(jtu.JaxTestCase): self.assertArraysEqual(jax.random.key_data(key1), jax.random.key_data(key2)) - impl = config.jax_default_prng_impl + impl = config.default_prng_impl.value key3 = jax.random.wrap_key_data(data, impl=impl) self.assertEqual(key1.dtype, key3.dtype) self.assertArraysEqual(jax.random.key_data(key1), diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 28520aaf1..90c3b6440 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -28,9 +28,9 @@ import numpy as np import jax from jax import lax -from jax import config from jax.sharding import Mesh from jax.sharding import PartitionSpec as P +from jax._src import config from jax._src import core from jax._src import test_util as jtu from jax._src import xla_bridge @@ -43,7 +43,7 @@ import jax.numpy as jnp from jax.experimental.shard_map import shard_map -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -589,7 +589,7 @@ class ShardMapTest(jtu.JaxTestCase): return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16, dtype=jnp.int32) - pspec = P('x') if config.jax_enable_custom_prng else P('x', None) + pspec = P('x') if config.enable_custom_prng.value else P('x', None) g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec) _ = g(sharded_rng) # don't crash! diff --git a/tests/state_test.py b/tests/state_test.py index 25c441565..d28349658 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -24,8 +24,8 @@ import jax from jax import random from jax import lax from jax._src import core +from jax._src import config from jax._src import linear_util as lu -from jax import config from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu from jax._src.util import tuple_insert @@ -48,7 +48,7 @@ from jax._src.state.primitives import (get_p, swap_p, addupdate_p, from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect, AccumEffect, AbstractRef) -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class StatePrimitivesTest(jtu.JaxTestCase): @@ -413,9 +413,9 @@ class StatePrimitivesTest(jtu.JaxTestCase): def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims, idx_bdims, out_bdim, op): - float_ = (jnp.dtype('float64') if jax.config.jax_enable_x64 else + float_ = (jnp.dtype('float64') if config.enable_x64.value else jnp.dtype('float32')) - int_ = (jnp.dtype('int64') if jax.config.jax_enable_x64 else + int_ = (jnp.dtype('int64') if config.enable_x64.value else jnp.dtype('int32')) axis_size = 7 out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b)