Migrate another subset of internal modules to use state objects

The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
This commit is contained in:
Sergei Lebedev 2023-10-11 08:45:30 -07:00 committed by jax authors
parent 04422b8aa4
commit 2f70ae700a
24 changed files with 139 additions and 147 deletions

View File

@ -28,9 +28,8 @@ from functools import partial
import inspect
import math
import typing
from typing import (Any, Callable, Literal,
NamedTuple, Optional, TypeVar, Union,
overload, cast)
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
cast)
import weakref
import numpy as np
@ -42,6 +41,7 @@ from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
prefix_errors, generate_key_paths)
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import effects
@ -53,7 +53,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr
from jax._src.core import eval_jaxpr, ShapedArray
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
@ -73,23 +73,12 @@ from jax._src import tree_util
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.config import (
config,
disable_jit as _disable_jit,
debug_nans as config_debug_nans,
debug_infs as config_debug_infs,
_thread_local_state as config_thread_local_state,
explicit_device_put_scope as config_explicit_device_put_scope,
explicit_device_get_scope as config_explicit_device_get_scope)
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
traceback_util.register_exclusion(__file__)
@ -121,27 +110,28 @@ def _nan_check_posthook(fun, args, kwargs, output):
dispatch.check_special(pjit.pjit_p.name, buffers)
except FloatingPointError:
# compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs
assert config.debug_nans.value or config.debug_infs.value
print("Invalid nan value encountered in the output of a C++-jit/pmap "
"function. Calling the de-optimized version.")
fun._cache_miss(*args, **kwargs)[0] # probably won't return
def _update_debug_special_global(_):
if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
if (config.config._read("jax_debug_nans") or
config.config._read("jax_debug_infs")):
jax_jit.global_state().post_hook = _nan_check_posthook
else:
jax_jit.global_state().post_hook = None
def _update_debug_special_thread_local(_):
if (getattr(config_thread_local_state, "jax_debug_nans", False) or
getattr(config_thread_local_state, "jax_debug_infs", False)):
if (getattr(config._thread_local_state, "jax_debug_nans", False) or
getattr(config._thread_local_state, "jax_debug_infs", False)):
jax_jit.thread_local_state().post_hook = _nan_check_posthook
else:
jax_jit.thread_local_state().post_hook = None
config_debug_nans._add_hooks(_update_debug_special_global,
config.debug_nans._add_hooks(_update_debug_special_global,
_update_debug_special_thread_local)
config_debug_infs._add_hooks(_update_debug_special_global,
config.debug_infs._add_hooks(_update_debug_special_global,
_update_debug_special_thread_local)
@ -376,7 +366,7 @@ def disable_jit(disable: bool = True):
Value of y is [2 4 6]
[5 7 9]
"""
with _disable_jit(disable):
with config.disable_jit(disable):
yield
@ -1579,7 +1569,7 @@ def pmap(
# TODO(yashkatariya): Move this out after shard_map is out of experimental and
# in _src
if config.jax_pmap_shmap_merge:
if config.pmap_shmap_merge.value:
from jax.experimental.shard_map import pmap
return pmap(fun, axis_name, in_axes=in_axes, out_axes=out_axes,
static_broadcasted_argnums=static_broadcasted_argnums,
@ -1670,7 +1660,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
dyn_args, dyn_in_axes = args, in_axes
args, in_tree = tree_flatten((dyn_args, kwargs))
if donate_tuple and not config.jax_debug_nans:
if donate_tuple and not config.debug_nans.value:
donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs)
else:
donated_invars = (False,) * len(args)
@ -2530,7 +2520,7 @@ def device_put(
This function is always asynchronous, i.e. returns immediately without
blocking the calling Python thread until any transfers are completed.
"""
with config_explicit_device_put_scope():
with config.explicit_device_put_scope():
if ((device is None or
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
(src is None or
@ -2624,7 +2614,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
return pxla.batched_device_put(stacked_aval, sharding, xs, list(devices))
with config_explicit_device_put_scope():
with config.explicit_device_put_scope():
return tree_map(_device_put_sharded, *shards)
@ -2674,7 +2664,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
assert len(xla.aval_to_xla_shapes(aval)) == 1
return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)
with config_explicit_device_put_scope():
with config.explicit_device_put_scope():
return tree_map(_device_put_replicated, x)
@ -2720,7 +2710,7 @@ def device_get(x: Any):
- device_put_sharded
- device_put_replicated
"""
with config_explicit_device_get_scope():
with config.explicit_device_get_scope():
for y in tree_leaves(x):
try:
y.copy_to_host_async()

View File

@ -20,9 +20,9 @@ from functools import partial
from typing import Any, Callable, Optional, Union
import jax
from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax import config
from jax.tree_util import (tree_flatten, tree_unflatten,
register_pytree_node, Partial)
from jax._src import core
@ -200,7 +200,7 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
# TODO(mattjj): add back these checks for dynamic shapes
# if config.jax_enable_checks:
# if config.enable_checks.value:
# ct_aval = core.get_aval(ct_env[v])
# joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
# assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)
@ -458,7 +458,7 @@ class JVPTracer(Tracer):
__slots__ = ['primal', 'tangent']
def __init__(self, trace, primal, tangent):
if config.jax_enable_checks:
if config.enable_checks.value:
_primal_tangent_shapes_match(primal, tangent)
self._trace = trace
self.primal = primal
@ -624,7 +624,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
if update_params:
params = update_params(params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
# TODO(mattjj,dougalm): handle consts, for now assume just args
which_lin = [is_undefined_primal(x) for x in args]
res_invars, _ = partition_list(which_lin, call_jaxpr.invars)

View File

@ -22,7 +22,7 @@ from typing import Any, Callable, Union
import numpy as np
import jax
from jax import config
from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import linear_util as lu
@ -318,7 +318,7 @@ class BatchTracer(Tracer):
def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
source_info: source_info_util.SourceInfo | None = None):
if config.jax_enable_checks:
if config.enable_checks.value:
assert type(batch_dim) in (NotMapped, int, RaggedAxis)
if type(batch_dim) is int:
aval = raise_to_shaped(core.get_aval(val))
@ -416,7 +416,7 @@ class BatchTrace(Trace):
return frame
def process_primitive(self, primitive, tracers, params):
if config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
primitive.abstract_eval(*(t.aval for t in tracers), **params)
vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
is_axis_primitive = primitive in axis_primitive_batchers

View File

@ -22,9 +22,9 @@ import itertools
import operator
from typing import Callable
from jax import config
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@ -129,7 +129,7 @@ def switch(index, branches: Sequence[Callable], *operands,
hi = np.array(len(branches) - 1, np.int32)
index = lax.clamp(lo, index, hi)
if (config.jax_disable_jit and
if (config.disable_jit.value and
isinstance(core.get_aval(index), ConcreteArray)):
return branches[int(index)](*operands)
@ -226,7 +226,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred_dtype))
if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray):
if pred:
return true_fun(*operands)
else:
@ -806,7 +806,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches, linear):
return jaxpr0.out_avals, joined_effects
def cond_bind(*args, branches, linear):
if config.jax_enable_checks:
if config.enable_checks.value:
avals = map(core.get_aval, args)
in_atoms = [core.Var(0, '', a) for a in avals] # dummies
_cond_typecheck(True, *in_atoms, branches=branches, linear=linear)

View File

@ -22,9 +22,9 @@ from typing import Any, Callable, Optional, TypeVar
import jax
import weakref
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax import config # type: ignore[no-redef]
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map, tree_flatten_with_path, keystr)
@ -214,7 +214,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
else:
length, = unique_lengths
if config.jax_disable_jit:
if config.disable_jit.value:
if length == 0:
raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
carry = init
@ -859,7 +859,7 @@ def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
else:
assert False, "Fixpoint not reached"
if config.jax_enable_checks: core.check_jaxpr(jaxpr.jaxpr)
if config.enable_checks.value: core.check_jaxpr(jaxpr.jaxpr)
new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
new_params = dict(eqn.params, num_consts=sum(used_consts),
@ -1105,7 +1105,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
return new_invals, [*carry_out, *ys_out]
def scan_bind(*args, **params):
if config.jax_enable_checks:
if config.enable_checks.value:
avals = _map(core.get_aval, args)
in_atoms = [core.Var(0, '', a) for a in avals] # dummies
_scan_typecheck(True, *in_atoms, **params)
@ -1190,7 +1190,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
"""
if not (callable(body_fun) and callable(cond_fun)):
raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
if config.jax_disable_jit:
if config.disable_jit.value:
try:
val = init_val
while cond_fun(val):
@ -1931,7 +1931,7 @@ def fori_loop(lower, upper, body_fun, init_val):
use_scan = False
if use_scan:
if config.jax_disable_jit and upper_ == lower_:
if config.disable_jit.value and upper_ == lower_:
# non-jit implementation of scan does not support length=0
return init_val

View File

@ -25,6 +25,7 @@ import numpy as np
import jax
from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@ -156,7 +157,7 @@ def dynamic_slice(
- :func:`jax.lax.dynamic_index_in_dim`
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
if jax.config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes)
else:
dynamic_sizes = []
@ -1090,7 +1091,7 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
msg = ("slice start_indices must be greater than or equal to zero, "
"got start_indices of {}.")
raise TypeError(msg.format(start_indices))
if not jax.config.jax_dynamic_shapes:
if not config.dynamic_shapes.value:
if not all(map(operator.ge, limit_indices, start_indices)):
msg = ("slice limit_indices must be greater than or equal to start_indices,"
" got start_indices {} and limit_indices {}.")

View File

@ -24,6 +24,7 @@ import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import util
@ -446,7 +447,7 @@ def softmax(x: ArrayLike,
See also:
:func:`log_softmax`
"""
if jax.config.jax_softmax_custom_jvp:
if config.softmax_custom_jvp.value:
# mypy is confused by the `functools.partial` application in the definition
# of `_softmax` and incorrectly concludes that `_softmax` returns
# `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.

View File

@ -47,6 +47,7 @@ from jax import lax
from jax.tree_util import tree_leaves, tree_flatten, tree_map
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@ -56,7 +57,6 @@ from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
@ -2321,7 +2321,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
def arange(start: DimSize, stop: Optional[DimSize] = None,
step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "arange")
if not jax.config.jax_dynamic_shapes:
if not config.dynamic_shapes.value:
util.check_arraylike("arange", start)
if stop is None and step is None:
start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
@ -3459,7 +3459,7 @@ def _einsum(
# NOTE(mattjj): this can fail non-deterministically in python3, maybe
# due to opt_einsum
assert jax.config.jax_dynamic_shapes or all(
assert config.dynamic_shapes.value or all(
name in lhs_names and name in rhs_names and
lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
for name in contracted_names), (
@ -4309,7 +4309,7 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
return result
# TODO(mattjj,dougalm): expand dynamic shape indexing support
if jax.config.jax_dynamic_shapes and arr.ndim > 0:
if config.dynamic_shapes.value and arr.ndim > 0:
try: aval = core.get_aval(idx)
except: pass
else:

View File

@ -37,6 +37,7 @@ from typing import Any, NamedTuple, Protocol, Union
import jax
from jax._src import core
from jax._src import config
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
@ -488,7 +489,7 @@ class Compiled(Stage):
# which might conflict here.
params = args[0]
args = args[1:]
if jax.config.jax_dynamic_shapes:
if config.dynamic_shapes.value:
raise NotImplementedError
if params.no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())

View File

@ -17,13 +17,14 @@ import os
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax import config
from jax.experimental.jax2tf.examples import saved_model_main
from jax.experimental.jax2tf.tests import tf_test_util
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
FLAGS = flags.FLAGS
@ -48,8 +49,8 @@ class SavedModelMainTest(tf_test_util.JaxToTfTestCase):
model="mnist_flax",
serving_batch_size=-1):
if (serving_batch_size == -1 and
config.jax2tf_default_native_serialization and
not config.jax_dynamic_shapes):
config.jax2tf_default_native_serialization.value and
not config.dynamic_shapes.value):
self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
FLAGS.model = model
FLAGS.model_classifier_layer = True

View File

@ -29,7 +29,6 @@ import numpy as np
import jax
from jax import lax
from jax import config
from jax import custom_derivatives
from jax import random
from jax import numpy as jnp
@ -45,6 +44,7 @@ from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@ -322,7 +322,7 @@ def convert(fun_jax: Callable,
if not enable_xla:
native_serialization = False
else:
native_serialization = config.jax2tf_default_native_serialization
native_serialization = config.jax2tf_default_native_serialization.value
if native_serialization and not enable_xla:
raise ValueError(
@ -1180,7 +1180,7 @@ class TensorFlowTracer(core.Tracer):
if isinstance(val, (tf.Tensor, tf.Variable)):
val_shape = val.shape
if config.jax_enable_checks:
if config.enable_checks.value:
assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than {val_shape=}"
# To compare types, we must handle float0 in JAX and x64 in TF
if phys_aval.dtype == dtypes.float0:
@ -1335,7 +1335,7 @@ class TensorFlowTrace(core.Trace):
# Check that the impl rule returned a value of expected shape and dtype
# TODO: adapt this to match polymorphic shapes
if config.jax_enable_checks:
if config.enable_checks.value:
if primitive.multiple_results:
for o, expected_aval in zip(out, out_aval): # type: ignore
assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
@ -2538,7 +2538,7 @@ tf_impl_with_avals[lax.reduce_p] = _reduce
def _cumred(lax_reduce_fn: Callable,
lax_reduce_window_fn: Callable,
extra_name_stack: str):
if config.jax2tf_associative_scan_reductions:
if config.jax2tf_associative_scan_reductions.value:
return _convert_jax_impl(partial(lax_control_flow.associative_scan,
lax_reduce_fn),
multiple_results=False,

View File

@ -48,12 +48,12 @@ from absl import testing
import numpy as np
import jax
from jax import config
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax._src import ad_util
from jax._src import config
from jax._src import dispatch
from jax._src import prng
from jax._src import test_util as jtu
@ -67,7 +67,7 @@ from jax._src import random as jax_random
# then the test file has to import jtu first (to define the flags) which is not
# desired if the test file is outside of this project (we don't want a
# dependency on jtu outside of jax repo).
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
FLAGS = config.FLAGS
Rng = Any # A random number generator
@ -2676,7 +2676,7 @@ for dtype in (np.float32, np.float64):
def wrap_and_split():
key = jax.random.key(42)
if jax.config.jax_enable_custom_prng:
if config.enable_custom_prng.value:
key = prng.random_wrap(key, impl=jax.random.default_prng_impl())
result = jax.random.split(key, 2)
return prng.random_unwrap(result)
@ -3314,7 +3314,7 @@ for padding, lhs_dilation, rhs_dilation in [
rhs_dilation=rhs_dilation)
key_types = [((4,), np.uint32)]
if config.jax_enable_x64:
if config.enable_x64.value:
key_types.append(((2,), np.uint64))
for algorithm in [lax.RandomAlgorithm.RNG_THREE_FRY,

View File

@ -63,8 +63,8 @@ from absl.testing import parameterized
import jax
from jax import dtypes
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax import config
from jax.experimental import jax2tf
from jax.interpreters import mlir
from jax._src.interpreters import xla
@ -72,7 +72,7 @@ from jax._src.interpreters import xla
import numpy as np
import tensorflow as tf # type: ignore[import]
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
# Import after parsing flags
from jax.experimental.jax2tf.tests import tf_test_util
@ -114,7 +114,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
func_jax = harness.dyn_fun
args = harness.dyn_args_maker(self.rng())
enable_xla = harness.params.get("enable_xla", True)
if config.jax2tf_default_native_serialization and not enable_xla:
if config.jax2tf_default_native_serialization.value and not enable_xla:
raise unittest.SkipTest("native_serialization not supported with enable_xla=False")
if ("eigh" == harness.group_name and
@ -122,7 +122,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
device == "tpu"):
raise unittest.SkipTest("b/264716764: error on tf.cast from c64 to f32")
if (config.jax2tf_default_native_serialization and
if (config.jax2tf_default_native_serialization.value and
device == "gpu" and
"lu" in harness.fullname):
raise unittest.SkipTest("b/269388847: lu failures on GPU")
@ -130,7 +130,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
def skipCustomCallTest(target: str):
raise unittest.SkipTest(
f"TODO(b/272239584): custom call target not guaranteed stable: {target}")
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
if device == "gpu":
if "custom_linear_solve_" in harness.fullname:
skipCustomCallTest("cusolver_geqrf, cublas_geqrf_batched")
@ -146,7 +146,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
enable_xla=enable_xla)
except Exception as e:
# TODO(b/264596006): custom calls are not registered properly with TF in OSS
if (config.jax2tf_default_native_serialization and
if (config.jax2tf_default_native_serialization.value and
"does not work with custom calls" in str(e)):
logging.warning("Suppressing error %s", e)
raise unittest.SkipTest("b/264596006: custom calls in native serialization fail in TF")
@ -257,7 +257,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
# The CPU has more supported types, and harnesses
self.assertEqual("cpu", jtu.device_under_test())
self.assertTrue(
config.x64_enabled,
config.enable_x64.value,
"Documentation generation must be run with JAX_ENABLE_X64=1")
with open(
@ -299,7 +299,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
y = np.int32(3)
self.ConvertAndCompare(jnp.floor_divide, x, y)
expected = jnp.floor_divide(x, y)
if not config.jax2tf_default_native_serialization:
if not config.jax2tf_default_native_serialization.value:
# With native serialization TF1 seems to want to run the converted code
# on the CPU even when the default backend is the TPU.
# Try it with TF 1 as well (#5831)

View File

@ -30,8 +30,8 @@ import unittest
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax import config
from jax import lax
from jax.experimental import jax2tf
from jax.experimental import pjit
@ -47,7 +47,7 @@ import numpy as np
import tensorflow as tf # type: ignore[import]
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util
@ -225,7 +225,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
# Annotation count for the input
count_in_P = 1 if in_shardings == "P" else 0
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
# With native serialization even unspecified in_shardings turn into replicated
count_in_replicated = 1 if in_shardings in [None, "missing"] else 0
else:
@ -400,14 +400,14 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
# Annotation count for the primal input and the grad output
count_in_P = self.GEQ(2) if in_shardings == "P" else 0
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
# With native serialization even unspecified shardings turn into replicated
count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0
else:
count_in_replicated = self.GEQ(2) if in_shardings is None else 0
# Annotation count for the contangent input
count_out_P = self.GEQ(1) if out_shardings == "P" else 0
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
# With native serialization even unspecified shardings turn into replicated
count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0
else:
@ -479,7 +479,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
"nested_pjit_sharded", "nested_pjit_replicated")
])
def test_pjit_eager_error(self, func="pjit_sharded"):
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("There is no error in eager mode for native serialization")
# Define some test functions

View File

@ -29,9 +29,9 @@ from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax import config
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax._src import config
from jax._src import xla_bridge
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -177,16 +177,14 @@ class JaxToTfTestCase(jtu.JaxTestCase):
# We run the tests using the maximum version supported, even though
# the default serialization version may be held back for a while to
# ensure compatibility
version = config.jax_serialization_version
self.addCleanup(functools.partial(config.update,
"jax_serialization_version", version))
version = config.jax_serialization_version.value
if self.use_max_serialization_version:
# Use the largest supported by both export and tfxla.call_module
version = min(export.maximum_supported_serialization_version,
tfxla.call_module_maximum_supported_version())
self.assertGreaterEqual(version,
export.minimum_supported_serialization_version)
config.update("jax_serialization_version", version)
self.enter_context(config.jax_serialization_version(version))
logging.info(
"Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)",
version,
@ -203,7 +201,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
def to_numpy_dtype(dt):
return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype
if not config.x64_enabled and canonicalize_dtypes:
if not config.enable_x64.value and canonicalize_dtypes:
self.assertEqual(
dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
@ -410,7 +408,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
# graph. We count the number of characters in the textual representation
# of the constant.
f_tf_graph = tf.function(tf_fun, autograph=False).get_concrete_function(*args).graph.as_graph_def()
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
# This way of finding constants may be brittle, if the constant representation
# contains >. It seems tobe hex-encoded, so this may be safe.
large_consts = [m for m in re.findall(r"dense<([^>]+)>", str(f_tf_graph)) if len(m) >= at_least]

View File

@ -26,12 +26,12 @@ import numpy as np
import jax
from jax import numpy as jnp
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
FLAGS = config.FLAGS
numpy_version = jtu.numpy_version()
@ -705,7 +705,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@unittest.skipIf(not config.jax_enable_x64, "test requires X64")
@unittest.skipIf(not config.enable_x64.value, "test requires X64")
@jtu.run_on_devices("cpu") # test is for CPU float64 precision
def testPercentilePrecision(self):
# Regression test for https://github.com/google/jax/issues/8513

View File

@ -27,11 +27,11 @@ from jax import lax
from jax.tree_util import register_pytree_node_class
import jax._src.scipy.sparse.linalg as sp_linalg
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
float_types = jtu.dtypes.floating
@ -100,7 +100,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
preconditioner=[None, 'identity', 'exact', 'random'],
)
def test_cg_against_scipy(self, shape, dtype, preconditioner):
if not config.x64_enabled:
if not config.enable_x64.value:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())
@ -221,7 +221,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
)
def test_bicgstab_against_scipy(
self, shape, dtype, preconditioner):
if not config.jax_enable_x64:
if not config.enable_x64.value:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())
@ -326,7 +326,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
)
def test_gmres_against_scipy(
self, shape, dtype, preconditioner, solve_method):
if not config.x64_enabled:
if not config.enable_x64.value:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())
@ -437,7 +437,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
"""
The Arnoldi decomposition within GMRES is correct.
"""
if not config.x64_enabled:
if not config.enable_x64.value:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())

View File

@ -24,6 +24,7 @@ from absl.testing import parameterized
import scipy.stats
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import ad_checkpoint
@ -33,8 +34,7 @@ from jax import random
import jax
import jax.numpy as jnp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class NNFunctionsTest(jtu.JaxTestCase):
@ -143,7 +143,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
# TODO(mattjj): include log_softmax in these extra tests if/when we add a
# custom_jvp rule for it (since otherwise it doesn't pass the numerical
# checks below).
if fn is nn.softmax and config.jax_softmax_custom_jvp:
if fn is nn.softmax and config.softmax_custom_jvp.value:
g_fun = lambda x: jnp.take(fn(x, where=m, initial=-jnp.inf),
jnp.array([0, 2, 3]))
jtu.check_grads(g_fun, (x,), order=2)
@ -153,7 +153,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3)
def testSoftmaxGradResiduals(self):
if not jax.config.jax_softmax_custom_jvp:
if not config.softmax_custom_jvp.value:
raise unittest.SkipTest("only applies when upgrade flag enabled")
x = jnp.array([5.5, 1.3, -4.2, 0.9])
res = ad_checkpoint.saved_residuals(nn.softmax, x)

View File

@ -30,6 +30,7 @@ import concurrent.futures
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import config
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages
@ -59,8 +60,7 @@ from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, safe_zip
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
prev_xla_flags = None
@ -955,7 +955,7 @@ class PJitTest(jtu.BufferDonationTestCase):
@jtu.with_mesh([('x', 2)])
def testWithCustomPRNGKey(self):
if not config.jax_enable_custom_prng:
if not config.enable_custom_prng.value:
raise unittest.SkipTest("test requires jax_enable_custom_prng")
key = prng.seed_with_impl(prng.rbg_prng_impl, 87)
# Make sure this doesn't crash

View File

@ -22,8 +22,8 @@ import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
@ -40,7 +40,7 @@ import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def _format_multiline(text):
@ -197,7 +197,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
@with_pure_and_io_callbacks
def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
if config.jax_enable_x64:
if config.enable_x64.value:
raise unittest.SkipTest("Test only needed when 64-bit mode disabled.")
@jax.jit

View File

@ -29,6 +29,7 @@ from jax import grad
from jax import lax
from jax import numpy as jnp
from jax import random
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
@ -36,8 +37,7 @@ from jax import vmap
from jax._src import prng as prng_internal
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
@ -63,7 +63,7 @@ class LaxRandomTest(jtu.JaxTestCase):
fail_prob = 0.003 if samples.dtype == jnp.bfloat16 else 0.01
# TODO(frostig): This reads enable_custom_prng as a proxy for
# whether RBG keys may be involved, but that's no longer exact.
if config.jax_enable_custom_prng and samples.dtype == jnp.bfloat16:
if config.enable_custom_prng.value and samples.dtype == jnp.bfloat16:
return
self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)
@ -382,7 +382,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=[np.float64], # NOTE: KS test fails with float32
)
def testBeta(self, a, b, dtype):
if not config.x64_enabled:
if not config.enable_x64.value:
raise SkipTest("skip test except on X64")
key = self.make_key(0)
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
@ -824,7 +824,7 @@ class LaxRandomTest(jtu.JaxTestCase):
assert np.unique(keys, axis=0).shape[0] == 2
def testStaticShapeErrors(self):
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("test only relevant when jit enabled")
@jax.jit
@ -991,7 +991,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_prng_jit_invariance(self, seed, type_):
if type_ == "int" and seed == (1 << 64) - 1:
self.skipTest("Expected failure: Python int too large.")
if not config.x64_enabled and seed > np.iinfo(np.int32).max:
if not config.enable_x64.value and seed > np.iinfo(np.int32).max:
self.skipTest("Expected failure: Python int too large.")
type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
args_maker = lambda: [type_(seed)]

View File

@ -30,6 +30,7 @@ from jax import lax
from jax import numpy as jnp
from jax import random
from jax import tree_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
@ -39,8 +40,7 @@ from jax.interpreters import xla
from jax._src import random as jax_random
from jax._src import prng as prng_internal
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
PRNG_IMPLS = list(prng_internal.prngs.items())
@ -247,7 +247,7 @@ class PrngTest(jtu.JaxTestCase):
finally:
xla.apply_primitive = apply_primitive
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
@skipIf(config.threefry_partitionable.value, 'changed random bit values')
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testRngRandomBits(self, make_key):
# Test specific outputs to ensure consistent random values between JAX versions.
@ -280,7 +280,7 @@ class PrngTest(jtu.JaxTestCase):
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
bits64 = random_bits(key, 64, (3,))
if config.x64_enabled:
if config.enable_x64.value:
expected64 = np.array([3982329540505020460, 16822122385914693683,
7882654074788531506], dtype=np.uint64)
else:
@ -315,11 +315,11 @@ class PrngTest(jtu.JaxTestCase):
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
bits64 = random_bits(key, 64, (3,))
expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32')
expected_dtype = np.dtype('uint64' if config.enable_x64.value else 'uint32')
self.assertEqual(bits64.shape, (3,))
self.assertEqual(bits64.dtype, expected_dtype)
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
@skipIf(config.threefry_partitionable.value, 'changed random bit values')
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testRngRandomBitsViewProperty(self, make_key):
# TODO: add 64-bit if it ever supports this property.
@ -338,7 +338,7 @@ class PrngTest(jtu.JaxTestCase):
@jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS)
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
@skipIf(config.threefry_partitionable.value, 'changed random bit values')
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testRandomDistributionValues(self, case, make_key):
"""
@ -355,9 +355,9 @@ class PrngTest(jtu.JaxTestCase):
* Considering adding a flag that reverts the new behavior, made
available for a deprecation window's amount of time.
"""
if config.x64_enabled and case.on_x64 == OnX64.SKIP:
if config.enable_x64.value:
self.skipTest("test produces different values when jax_enable_x64=True")
if not config.x64_enabled and case.on_x64 == OnX64.ONLY:
if not config.enable_x64.value:
self.skipTest("test only valid when jax_enable_x64=True")
with jax.default_prng_impl(case.prng_impl):
func = getattr(random, case.name)
@ -368,7 +368,7 @@ class PrngTest(jtu.JaxTestCase):
actual = func(key, **case.params, shape=case.shape)
self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
@skipIf(config.threefry_partitionable.value, 'changed random bit values')
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testPRNGValues(self, make_key):
# Test to ensure consistent random values between JAX versions
@ -376,7 +376,7 @@ class PrngTest(jtu.JaxTestCase):
self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
dtypes.canonicalize_dtype(jnp.int_))
if config.x64_enabled:
if config.enable_x64.value:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8, dtype='int64'),
np.array([[7, 2, 6],
@ -407,7 +407,7 @@ class PrngTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, msg):
random.bits(make_key(0), (3, 4), np.dtype('float16'))
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
@skipIf(not config.threefry_partitionable.value, 'enable after upgrade')
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_threefry_split_fold_in_symmetry(self, make_key):
with jax.default_prng_impl('threefry2x32'):
@ -420,7 +420,7 @@ class PrngTest(jtu.JaxTestCase):
self.assertArraysEqual(f2, s2)
self.assertArraysEqual(f3, s3)
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
@skipIf(not config.threefry_partitionable.value, 'enable after upgrade')
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
# See https://github.com/google/jax/issues/7708
@ -435,7 +435,7 @@ class PrngTest(jtu.JaxTestCase):
self.assertArraysEqual(f2, s2)
self.assertArraysEqual(f3, s3)
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
@skipIf(config.threefry_partitionable.value, 'changed random bit values')
def test_loggamma_nan_corner_case(self):
# regression test for https://github.com/google/jax/issues/17922
# This particular key previously led to NaN output.
@ -459,20 +459,20 @@ class PrngTest(jtu.JaxTestCase):
{"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
{"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
{"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]},
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]},
{"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
{"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]},
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]},
]
for params in [dict(**d, make_key=ctor) for ctor in KEY_CTORS]
])
@ -482,7 +482,7 @@ class PrngTest(jtu.JaxTestCase):
maker = lambda k: random.key_data(jax.jit(make_key)(k))
else:
maker = lambda k: random.key_data(make_key(k))
if (jit and typ is int and not config.x64_enabled and
if (jit and typ is int and not config.enable_x64.value and
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
# We expect an error to be raised.
# NOTE: we check 'if jit' because some people rely on builtin int seeds
@ -607,7 +607,7 @@ class KeyArrayTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, "Cannot interpret"):
jnp.issubdtype(key, dtypes.prng_key)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
@skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag')
def test_construction_upgrade_flag(self):
key = random.PRNGKey(42)
self.assertIsInstance(key, jax_random.PRNGKeyArray)
@ -1059,7 +1059,7 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key2))
impl = config.jax_default_prng_impl
impl = config.default_prng_impl.value
key3 = jax.random.wrap_key_data(data, impl=impl)
self.assertEqual(key1.dtype, key3.dtype)
self.assertArraysEqual(jax.random.key_data(key1),

View File

@ -28,9 +28,9 @@ import numpy as np
import jax
from jax import lax
from jax import config
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -43,7 +43,7 @@ import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@ -589,7 +589,7 @@ class ShardMapTest(jtu.JaxTestCase):
return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16,
dtype=jnp.int32)
pspec = P('x') if config.jax_enable_custom_prng else P('x', None)
pspec = P('x') if config.enable_custom_prng.value else P('x', None)
g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec)
_ = g(sharded_rng) # don't crash!

View File

@ -24,8 +24,8 @@ import jax
from jax import random
from jax import lax
from jax._src import core
from jax._src import config
from jax._src import linear_util as lu
from jax import config
from jax._src.interpreters import partial_eval as pe
from jax._src import test_util as jtu
from jax._src.util import tuple_insert
@ -48,7 +48,7 @@ from jax._src.state.primitives import (get_p, swap_p, addupdate_p,
from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
AccumEffect, AbstractRef)
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class StatePrimitivesTest(jtu.JaxTestCase):
@ -413,9 +413,9 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims,
idx_bdims, out_bdim, op):
float_ = (jnp.dtype('float64') if jax.config.jax_enable_x64 else
float_ = (jnp.dtype('float64') if config.enable_x64.value else
jnp.dtype('float32'))
int_ = (jnp.dtype('int64') if jax.config.jax_enable_x64 else
int_ = (jnp.dtype('int64') if config.enable_x64.value else
jnp.dtype('int32'))
axis_size = 7
out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b)