Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config` with statically-typed state objects, which are friendlier to type checkers and IDEs. This is a follow-up to #18008.

PiperOrigin-RevId: 572587137
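The change is mechanical throughout the diff below: a dynamic attribute lookup such as `config.jax_debug_nans` becomes a read of a typed state object, `config.debug_nans.value`, and temporary overrides use the same object as a context manager. A minimal sketch of the two access styles, using only state objects that actually appear in this diff (`debug_nans`, `disable_jit`):

```python
from jax._src import config

# Old style (what this commit removes): dynamic lookups on the global
# `jax.config` object, invisible to type checkers and IDEs.
#   if jax.config.jax_debug_nans: ...

# New style: each option is a module-level state object with a typed
# `.value` property...
if config.debug_nans.value:
  print("NaN debugging is enabled")

# ...and the same object doubles as a context manager for scoped overrides,
# as in the `disable_jit` hunk below.
with config.disable_jit(True):
  pass  # jit is disabled only within this block
```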
This commit is contained in:
parent 04422b8aa4
commit 2f70ae700a
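For orientation, the essential shape of such a state object is a typed value holder whose read is `.value` and whose call is a scoped override. The sketch below is a deliberately stripped-down illustration, not JAX's actual implementation (the real objects in `jax/_src/config.py` also integrate with command-line flags, environment variables, thread-local state, and change hooks such as the `_add_hooks` call visible in the first file below):

```python
import contextlib
from typing import Generic, Iterator, TypeVar

T = TypeVar("T")

class _State(Generic[T]):
  """Illustrative stand-in for a statically-typed config option."""

  def __init__(self, name: str, default: T):
    self._name = name
    self._value = default

  @property
  def value(self) -> T:
    # A type checker sees a concrete T here, unlike a dynamic attribute
    # lookup on a config object.
    return self._value

  @contextlib.contextmanager
  def __call__(self, new_value: T) -> Iterator[None]:
    # Scoped override: set, yield, and always restore.
    old, self._value = self._value, new_value
    try:
      yield
    finally:
      self._value = old

debug_nans = _State("jax_debug_nans", False)  # reads as bool to the checker
```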
@@ -28,9 +28,8 @@ from functools import partial
 import inspect
 import math
 import typing
-from typing import (Any, Callable, Literal,
-                    NamedTuple, Optional, TypeVar, Union,
-                    overload, cast)
+from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
+                    cast)
 import weakref

 import numpy as np
@@ -42,6 +41,7 @@ from jax._src.tree_util import (
     tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
     tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
     prefix_errors, generate_key_paths)
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import effects
@@ -53,7 +53,7 @@ from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import pjit
 from jax._src import xla_bridge as xb
-from jax._src.core import eval_jaxpr
+from jax._src.core import eval_jaxpr, ShapedArray
 from jax._src.api_util import (
     flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
     argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
@@ -73,23 +73,12 @@ from jax._src import tree_util
 from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
 from jax._src import util


+from jax._src.interpreters import partial_eval as pe
+from jax._src.interpreters import mlir
+from jax._src.interpreters import xla

-from jax._src.config import (
-    config,
-    disable_jit as _disable_jit,
-    debug_nans as config_debug_nans,
-    debug_infs as config_debug_infs,
-    _thread_local_state as config_thread_local_state,
-    explicit_device_put_scope as config_explicit_device_put_scope,
-    explicit_device_get_scope as config_explicit_device_get_scope)
-from jax._src.core import ShapedArray
-from jax._src.interpreters import ad
-from jax._src.interpreters import batching
-from jax._src.interpreters import mlir
-from jax._src.interpreters import partial_eval as pe
-from jax._src.interpreters import pxla
-from jax._src.interpreters import xla


 traceback_util.register_exclusion(__file__)
@@ -121,27 +110,28 @@ def _nan_check_posthook(fun, args, kwargs, output):
     dispatch.check_special(pjit.pjit_p.name, buffers)
   except FloatingPointError:
     # compiled_fun can only raise in this case
-    assert config.jax_debug_nans or config.jax_debug_infs
+    assert config.debug_nans.value or config.debug_infs.value
     print("Invalid nan value encountered in the output of a C++-jit/pmap "
           "function. Calling the de-optimized version.")
     fun._cache_miss(*args, **kwargs)[0]  # probably won't return

 def _update_debug_special_global(_):
-  if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
+  if (config.config._read("jax_debug_nans") or
+      config.config._read("jax_debug_infs")):
     jax_jit.global_state().post_hook = _nan_check_posthook
   else:
     jax_jit.global_state().post_hook = None

 def _update_debug_special_thread_local(_):
-  if (getattr(config_thread_local_state, "jax_debug_nans", False) or
-      getattr(config_thread_local_state, "jax_debug_infs", False)):
+  if (getattr(config._thread_local_state, "jax_debug_nans", False) or
+      getattr(config._thread_local_state, "jax_debug_infs", False)):
    jax_jit.thread_local_state().post_hook = _nan_check_posthook
  else:
    jax_jit.thread_local_state().post_hook = None

-config_debug_nans._add_hooks(_update_debug_special_global,
+config.debug_nans._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)
-config_debug_infs._add_hooks(_update_debug_special_global,
+config.debug_infs._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)


@@ -376,7 +366,7 @@ def disable_jit(disable: bool = True):
   Value of y is [2 4 6]
   [5 7 9]
   """
-  with _disable_jit(disable):
+  with config.disable_jit(disable):
     yield


@@ -1579,7 +1569,7 @@ def pmap(

   # TODO(yashkatariya): Move this out after shard_map is out of experimental and
   # in _src
-  if config.jax_pmap_shmap_merge:
+  if config.pmap_shmap_merge.value:
     from jax.experimental.shard_map import pmap
     return pmap(fun, axis_name, in_axes=in_axes, out_axes=out_axes,
                 static_broadcasted_argnums=static_broadcasted_argnums,
@@ -1670,7 +1660,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
     dyn_args, dyn_in_axes = args, in_axes
   args, in_tree = tree_flatten((dyn_args, kwargs))

-  if donate_tuple and not config.jax_debug_nans:
+  if donate_tuple and not config.debug_nans.value:
     donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs)
   else:
     donated_invars = (False,) * len(args)
@@ -2530,7 +2520,7 @@ def device_put(
   This function is always asynchronous, i.e. returns immediately without
   blocking the calling Python thread until any transfers are completed.
   """
-  with config_explicit_device_put_scope():
+  with config.explicit_device_put_scope():
     if ((device is None or
          isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
         (src is None or
@@ -2624,7 +2614,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
     return pxla.batched_device_put(stacked_aval, sharding, xs, list(devices))


-  with config_explicit_device_put_scope():
+  with config.explicit_device_put_scope():
     return tree_map(_device_put_sharded, *shards)


@@ -2674,7 +2664,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
     assert len(xla.aval_to_xla_shapes(aval)) == 1
     return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)

-  with config_explicit_device_put_scope():
+  with config.explicit_device_put_scope():
     return tree_map(_device_put_replicated, x)


@@ -2720,7 +2710,7 @@ def device_get(x: Any):
     - device_put_sharded
     - device_put_replicated
   """
-  with config_explicit_device_get_scope():
+  with config.explicit_device_get_scope():
     for y in tree_leaves(x):
       try:
         y.copy_to_host_async()
@@ -20,9 +20,9 @@ from functools import partial
 from typing import Any, Callable, Optional, Union

 import jax
+from jax._src import config
 from jax._src import linear_util as lu
 from jax._src.interpreters import partial_eval as pe
-from jax import config
 from jax.tree_util import (tree_flatten, tree_unflatten,
                            register_pytree_node, Partial)
 from jax._src import core
@@ -200,7 +200,7 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
       ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
     ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
     # TODO(mattjj): add back these checks for dynamic shapes
-    # if config.jax_enable_checks:
+    # if config.enable_checks.value:
     #   ct_aval = core.get_aval(ct_env[v])
     #   joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
     #   assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)
@@ -458,7 +458,7 @@ class JVPTracer(Tracer):
   __slots__ = ['primal', 'tangent']

   def __init__(self, trace, primal, tangent):
-    if config.jax_enable_checks:
+    if config.enable_checks.value:
       _primal_tangent_shapes_match(primal, tangent)
     self._trace = trace
     self.primal = primal
@@ -624,7 +624,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
   if update_params:
     params = update_params(params, map(is_undefined_primal, args),
                            [type(x) is not Zero for x in ct])
-  if config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     # TODO(mattjj,dougalm): handle consts, for now assume just args
     which_lin = [is_undefined_primal(x) for x in args]
     res_invars, _ = partition_list(which_lin, call_jaxpr.invars)
@@ -22,7 +22,7 @@ from typing import Any, Callable, Union
 import numpy as np

 import jax
-from jax import config
+from jax._src import config
 from jax._src import core
 from jax._src import source_info_util
 from jax._src import linear_util as lu
@@ -318,7 +318,7 @@ class BatchTracer(Tracer):

   def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
                source_info: source_info_util.SourceInfo | None = None):
-    if config.jax_enable_checks:
+    if config.enable_checks.value:
       assert type(batch_dim) in (NotMapped, int, RaggedAxis)
       if type(batch_dim) is int:
         aval = raise_to_shaped(core.get_aval(val))
@@ -416,7 +416,7 @@ class BatchTrace(Trace):
     return frame

   def process_primitive(self, primitive, tracers, params):
-    if config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       primitive.abstract_eval(*(t.aval for t in tracers), **params)
     vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
     is_axis_primitive = primitive in axis_primitive_batchers
@@ -22,9 +22,9 @@ import itertools
 import operator
 from typing import Callable

-from jax import config
 from jax.tree_util import tree_flatten, tree_unflatten
 from jax._src import ad_util
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
@@ -129,7 +129,7 @@ def switch(index, branches: Sequence[Callable], *operands,
   hi = np.array(len(branches) - 1, np.int32)
   index = lax.clamp(lo, index, hi)

-  if (config.jax_disable_jit and
+  if (config.disable_jit.value and
       isinstance(core.get_aval(index), ConcreteArray)):
     return branches[int(index)](*operands)

@@ -226,7 +226,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
     msg = ("Pred type must be either boolean or number, got {}.")
     raise TypeError(msg.format(pred_dtype))

-  if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
+  if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray):
     if pred:
       return true_fun(*operands)
     else:
@@ -806,7 +806,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches, linear):
   return jaxpr0.out_avals, joined_effects

 def cond_bind(*args, branches, linear):
-  if config.jax_enable_checks:
+  if config.enable_checks.value:
     avals = map(core.get_aval, args)
     in_atoms = [core.Var(0, '', a) for a in avals]  # dummies
     _cond_typecheck(True, *in_atoms, branches=branches, linear=linear)
@@ -22,9 +22,9 @@ from typing import Any, Callable, Optional, TypeVar

 import jax
 import weakref
+from jax._src import config
 from jax._src import core
 from jax._src import linear_util as lu
-from jax import config  # type: ignore[no-redef]
 from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
 from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                            tree_map, tree_flatten_with_path, keystr)
@@ -214,7 +214,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
   else:
     length, = unique_lengths

-  if config.jax_disable_jit:
+  if config.disable_jit.value:
     if length == 0:
       raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
     carry = init
@@ -859,7 +859,7 @@ def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
       used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
   else:
     assert False, "Fixpoint not reached"
-  if config.jax_enable_checks: core.check_jaxpr(jaxpr.jaxpr)
+  if config.enable_checks.value: core.check_jaxpr(jaxpr.jaxpr)

   new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
   new_params = dict(eqn.params, num_consts=sum(used_consts),
@@ -1105,7 +1105,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
   return new_invals, [*carry_out, *ys_out]

 def scan_bind(*args, **params):
-  if config.jax_enable_checks:
+  if config.enable_checks.value:
     avals = _map(core.get_aval, args)
     in_atoms = [core.Var(0, '', a) for a in avals]  # dummies
     _scan_typecheck(True, *in_atoms, **params)
@@ -1190,7 +1190,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
   """
   if not (callable(body_fun) and callable(cond_fun)):
     raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
-  if config.jax_disable_jit:
+  if config.disable_jit.value:
     try:
       val = init_val
       while cond_fun(val):
@@ -1931,7 +1931,7 @@ def fori_loop(lower, upper, body_fun, init_val):
     use_scan = False

   if use_scan:
-    if config.jax_disable_jit and upper_ == lower_:
+    if config.disable_jit.value and upper_ == lower_:
       # non-jit implementation of scan does not support length=0
       return init_val

@@ -25,6 +25,7 @@ import numpy as np
 import jax

 from jax._src import ad_util
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
@@ -156,7 +157,7 @@ def dynamic_slice(
     - :func:`jax.lax.dynamic_index_in_dim`
   """
   start_indices = _dynamic_slice_indices(operand, start_indices)
-  if jax.config.jax_dynamic_shapes:
+  if config.dynamic_shapes.value:
     dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes)
   else:
     dynamic_sizes = []
@@ -1090,7 +1091,7 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
     msg = ("slice start_indices must be greater than or equal to zero, "
            "got start_indices of {}.")
     raise TypeError(msg.format(start_indices))
-  if not jax.config.jax_dynamic_shapes:
+  if not config.dynamic_shapes.value:
     if not all(map(operator.ge, limit_indices, start_indices)):
       msg = ("slice limit_indices must be greater than or equal to start_indices,"
              " got start_indices {} and limit_indices {}.")
@@ -24,6 +24,7 @@ import jax
 import jax.numpy as jnp
 from jax import custom_jvp
 from jax import lax
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import util
@@ -446,7 +447,7 @@ def softmax(x: ArrayLike,
   See also:
     :func:`log_softmax`
   """
-  if jax.config.jax_softmax_custom_jvp:
+  if config.softmax_custom_jvp.value:
     # mypy is confused by the `functools.partial` application in the definition
     # of `_softmax` and incorrectly concludes that `_softmax` returns
     # `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.
@@ -47,6 +47,7 @@ from jax import lax
 from jax.tree_util import tree_leaves, tree_flatten, tree_map

 from jax._src import api_util
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
@@ -56,7 +57,6 @@ from jax._src.core import ShapedArray, ConcreteArray
 from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                               _sort_le_comparator, PrecisionLike)
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import xla_client
 from jax._src.numpy import reductions
 from jax._src.numpy import ufuncs
 from jax._src.numpy import util
@@ -2321,7 +2321,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
 def arange(start: DimSize, stop: Optional[DimSize] = None,
            step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
   dtypes.check_user_dtype_supported(dtype, "arange")
-  if not jax.config.jax_dynamic_shapes:
+  if not config.dynamic_shapes.value:
     util.check_arraylike("arange", start)
     if stop is None and step is None:
       start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
@@ -3459,7 +3459,7 @@ def _einsum(

       # NOTE(mattjj): this can fail non-deterministically in python3, maybe
       # due to opt_einsum
-      assert jax.config.jax_dynamic_shapes or all(
+      assert config.dynamic_shapes.value or all(
         name in lhs_names and name in rhs_names and
         lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
         for name in contracted_names), (
@@ -4309,7 +4309,7 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
     return result

   # TODO(mattjj,dougalm): expand dynamic shape indexing support
-  if jax.config.jax_dynamic_shapes and arr.ndim > 0:
+  if config.dynamic_shapes.value and arr.ndim > 0:
     try: aval = core.get_aval(idx)
     except: pass
     else:
@@ -37,6 +37,7 @@ from typing import Any, NamedTuple, Protocol, Union
 import jax

 from jax._src import core
+from jax._src import config
 from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import tree_util
@@ -488,7 +489,7 @@ class Compiled(Stage):
     # which might conflict here.
     params = args[0]
     args = args[1:]
-    if jax.config.jax_dynamic_shapes:
+    if config.dynamic_shapes.value:
       raise NotImplementedError
     if params.no_kwargs and kwargs:
       kws = ', '.join(kwargs.keys())
@@ -17,13 +17,14 @@ import os
 from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
-from jax import config
+from jax._src import config
+from jax._src import test_util as jtu

 from jax.experimental.jax2tf.examples import saved_model_main
 from jax.experimental.jax2tf.tests import tf_test_util

-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 FLAGS = flags.FLAGS


@@ -48,8 +49,8 @@ class SavedModelMainTest(tf_test_util.JaxToTfTestCase):
       model="mnist_flax",
       serving_batch_size=-1):
     if (serving_batch_size == -1 and
-        config.jax2tf_default_native_serialization and
-        not config.jax_dynamic_shapes):
+        config.jax2tf_default_native_serialization.value and
+        not config.dynamic_shapes.value):
       self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
     FLAGS.model = model
     FLAGS.model_classifier_layer = True
@@ -29,7 +29,6 @@ import numpy as np

 import jax
 from jax import lax
-from jax import config
 from jax import custom_derivatives
 from jax import random
 from jax import numpy as jnp
@@ -45,6 +44,7 @@ from jax._src import ad_checkpoint
 from jax._src import ad_util
 from jax._src import api
 from jax._src import api_util
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
@@ -322,7 +322,7 @@ def convert(fun_jax: Callable,
     if not enable_xla:
       native_serialization = False
     else:
-      native_serialization = config.jax2tf_default_native_serialization
+      native_serialization = config.jax2tf_default_native_serialization.value

   if native_serialization and not enable_xla:
     raise ValueError(
@@ -1180,7 +1180,7 @@ class TensorFlowTracer(core.Tracer):
     if isinstance(val, (tf.Tensor, tf.Variable)):
       val_shape = val.shape

-      if config.jax_enable_checks:
+      if config.enable_checks.value:
         assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than {val_shape=}"
       # To compare types, we must handle float0 in JAX and x64 in TF
       if phys_aval.dtype == dtypes.float0:
@@ -1335,7 +1335,7 @@ class TensorFlowTrace(core.Trace):

     # Check that the impl rule returned a value of expected shape and dtype
     # TODO: adapt this to match polymorphic shapes
-    if config.jax_enable_checks:
+    if config.enable_checks.value:
       if primitive.multiple_results:
         for o, expected_aval in zip(out, out_aval):  # type: ignore
           assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
@@ -2538,7 +2538,7 @@ tf_impl_with_avals[lax.reduce_p] = _reduce
 def _cumred(lax_reduce_fn: Callable,
             lax_reduce_window_fn: Callable,
             extra_name_stack: str):
-  if config.jax2tf_associative_scan_reductions:
+  if config.jax2tf_associative_scan_reductions.value:
     return _convert_jax_impl(partial(lax_control_flow.associative_scan,
                                      lax_reduce_fn),
                              multiple_results=False,
@@ -48,12 +48,12 @@ from absl import testing
 import numpy as np

 import jax
-from jax import config
 from jax import dtypes
 from jax import lax
 from jax import numpy as jnp

 from jax._src import ad_util
+from jax._src import config
 from jax._src import dispatch
 from jax._src import prng
 from jax._src import test_util as jtu
@@ -67,7 +67,7 @@ from jax._src import random as jax_random
 # then the test file has to import jtu first (to define the flags) which is not
 # desired if the test file is outside of this project (we don't want a
 # dependency on jtu outside of jax repo).
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 FLAGS = config.FLAGS

 Rng = Any  # A random number generator
@@ -2676,7 +2676,7 @@ for dtype in (np.float32, np.float64):

   def wrap_and_split():
     key = jax.random.key(42)
-    if jax.config.jax_enable_custom_prng:
+    if config.enable_custom_prng.value:
       key = prng.random_wrap(key, impl=jax.random.default_prng_impl())
     result = jax.random.split(key, 2)
     return prng.random_unwrap(result)
@@ -3314,7 +3314,7 @@ for padding, lhs_dilation, rhs_dilation in [
                      rhs_dilation=rhs_dilation)

 key_types = [((4,), np.uint32)]
-if config.jax_enable_x64:
+if config.enable_x64.value:
   key_types.append(((2,), np.uint64))

 for algorithm in [lax.RandomAlgorithm.RNG_THREE_FRY,
@@ -63,8 +63,8 @@ from absl.testing import parameterized
 import jax
 from jax import dtypes
 from jax import numpy as jnp
+from jax._src import config
 from jax._src import test_util as jtu
-from jax import config
 from jax.experimental import jax2tf
 from jax.interpreters import mlir
 from jax._src.interpreters import xla
@@ -72,7 +72,7 @@ from jax._src.interpreters import xla
 import numpy as np
 import tensorflow as tf  # type: ignore[import]

-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()

 # Import after parsing flags
 from jax.experimental.jax2tf.tests import tf_test_util
@@ -114,7 +114,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     func_jax = harness.dyn_fun
     args = harness.dyn_args_maker(self.rng())
     enable_xla = harness.params.get("enable_xla", True)
-    if config.jax2tf_default_native_serialization and not enable_xla:
+    if config.jax2tf_default_native_serialization.value and not enable_xla:
       raise unittest.SkipTest("native_serialization not supported with enable_xla=False")

     if ("eigh" == harness.group_name and
@@ -122,7 +122,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
         device == "tpu"):
       raise unittest.SkipTest("b/264716764: error on tf.cast from c64 to f32")

-    if (config.jax2tf_default_native_serialization and
+    if (config.jax2tf_default_native_serialization.value and
         device == "gpu" and
         "lu" in harness.fullname):
       raise unittest.SkipTest("b/269388847: lu failures on GPU")
@@ -130,7 +130,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     def skipCustomCallTest(target: str):
       raise unittest.SkipTest(
         f"TODO(b/272239584): custom call target not guaranteed stable: {target}")
-    if config.jax2tf_default_native_serialization:
+    if config.jax2tf_default_native_serialization.value:
       if device == "gpu":
         if "custom_linear_solve_" in harness.fullname:
           skipCustomCallTest("cusolver_geqrf, cublas_geqrf_batched")
@@ -146,7 +146,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
                               enable_xla=enable_xla)
     except Exception as e:
       # TODO(b/264596006): custom calls are not registered properly with TF in OSS
-      if (config.jax2tf_default_native_serialization and
+      if (config.jax2tf_default_native_serialization.value and
           "does not work with custom calls" in str(e)):
         logging.warning("Suppressing error %s", e)
         raise unittest.SkipTest("b/264596006: custom calls in native serialization fail in TF")
@@ -257,7 +257,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     # The CPU has more supported types, and harnesses
     self.assertEqual("cpu", jtu.device_under_test())
     self.assertTrue(
-        config.x64_enabled,
+        config.enable_x64.value,
         "Documentation generation must be run with JAX_ENABLE_X64=1")

     with open(
@@ -299,7 +299,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     y = np.int32(3)
     self.ConvertAndCompare(jnp.floor_divide, x, y)
     expected = jnp.floor_divide(x, y)
-    if not config.jax2tf_default_native_serialization:
+    if not config.jax2tf_default_native_serialization.value:
       # With native serialization TF1 seems to want to run the converted code
       # on the CPU even when the default backend is the TPU.
      # Try it with TF 1 as well (#5831)
@@ -30,8 +30,8 @@ import unittest
 from absl.testing import absltest

 import jax
+from jax._src import config
 from jax._src import test_util as jtu
-from jax import config
 from jax import lax
 from jax.experimental import jax2tf
 from jax.experimental import pjit
@@ -47,7 +47,7 @@ import numpy as np

 import tensorflow as tf  # type: ignore[import]

-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()

 # Must come after initializing the flags
 from jax.experimental.jax2tf.tests import tf_test_util
@@ -225,7 +225,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):

     # Annotation count for the input
     count_in_P = 1 if in_shardings == "P" else 0
-    if config.jax2tf_default_native_serialization:
+    if config.jax2tf_default_native_serialization.value:
       # With native serialization even unspecified in_shardings turn into replicated
       count_in_replicated = 1 if in_shardings in [None, "missing"] else 0
     else:
@@ -400,14 +400,14 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):

     # Annotation count for the primal input and the grad output
     count_in_P = self.GEQ(2) if in_shardings == "P" else 0
-    if config.jax2tf_default_native_serialization:
+    if config.jax2tf_default_native_serialization.value:
       # With native serialization even unspecified shardings turn into replicated
       count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0
     else:
       count_in_replicated = self.GEQ(2) if in_shardings is None else 0
     # Annotation count for the contangent input
     count_out_P = self.GEQ(1) if out_shardings == "P" else 0
-    if config.jax2tf_default_native_serialization:
+    if config.jax2tf_default_native_serialization.value:
       # With native serialization even unspecified shardings turn into replicated
       count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0
     else:
@@ -479,7 +479,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
                             "nested_pjit_sharded", "nested_pjit_replicated")
   ])
   def test_pjit_eager_error(self, func="pjit_sharded"):
-    if config.jax2tf_default_native_serialization:
+    if config.jax2tf_default_native_serialization.value:
       raise unittest.SkipTest("There is no error in eager mode for native serialization")

     # Define some test functions
@@ -29,9 +29,9 @@ from jax import numpy as jnp
 from jax._src import test_util as jtu
 from jax import tree_util

-from jax import config
 from jax.experimental import jax2tf
 from jax.experimental.export import export
+from jax._src import config
 from jax._src import xla_bridge
 import numpy as np
 import tensorflow as tf  # type: ignore[import]
@@ -177,16 +177,14 @@ class JaxToTfTestCase(jtu.JaxTestCase):
     # We run the tests using the maximum version supported, even though
     # the default serialization version may be held back for a while to
     # ensure compatibility
-    version = config.jax_serialization_version
-    self.addCleanup(functools.partial(config.update,
-                                      "jax_serialization_version", version))
+    version = config.jax_serialization_version.value
     if self.use_max_serialization_version:
       # Use the largest supported by both export and tfxla.call_module
       version = min(export.maximum_supported_serialization_version,
                     tfxla.call_module_maximum_supported_version())
       self.assertGreaterEqual(version,
                               export.minimum_supported_serialization_version)
-    config.update("jax_serialization_version", version)
+    self.enter_context(config.jax_serialization_version(version))
     logging.info(
       "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)",
       version,
@@ -203,7 +201,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
     def to_numpy_dtype(dt):
       return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype

-    if not config.x64_enabled and canonicalize_dtypes:
+    if not config.enable_x64.value and canonicalize_dtypes:
       self.assertEqual(
           dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
           dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
@@ -410,7 +408,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
     # graph. We count the number of characters in the textual representation
     # of the constant.
     f_tf_graph = tf.function(tf_fun, autograph=False).get_concrete_function(*args).graph.as_graph_def()
-    if config.jax2tf_default_native_serialization:
+    if config.jax2tf_default_native_serialization.value:
       # This way of finding constants may be brittle, if the constant representation
       # contains >. It seems tobe hex-encoded, so this may be safe.
       large_consts = [m for m in re.findall(r"dense<([^>]+)>", str(f_tf_graph)) if len(m) >= at_least]
@@ -26,12 +26,12 @@ import numpy as np
 import jax
 from jax import numpy as jnp

+from jax._src import config
 from jax._src import dtypes
 from jax._src import test_util as jtu
 from jax._src.util import NumpyComplexWarning

-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 FLAGS = config.FLAGS

 numpy_version = jtu.numpy_version()
@@ -705,7 +705,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
                           tol=tol)
     self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)

-  @unittest.skipIf(not config.jax_enable_x64, "test requires X64")
+  @unittest.skipIf(not config.enable_x64.value, "test requires X64")
   @jtu.run_on_devices("cpu")  # test is for CPU float64 precision
   def testPercentilePrecision(self):
     # Regression test for https://github.com/google/jax/issues/8513
@@ -27,11 +27,11 @@ from jax import lax
 from jax.tree_util import register_pytree_node_class

 import jax._src.scipy.sparse.linalg as sp_linalg
+from jax._src import config
 from jax._src import dtypes
 from jax._src import test_util as jtu

-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()


 float_types = jtu.dtypes.floating
@@ -100,7 +100,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
       preconditioner=[None, 'identity', 'exact', 'random'],
   )
   def test_cg_against_scipy(self, shape, dtype, preconditioner):
-    if not config.x64_enabled:
+    if not config.enable_x64.value:
       raise unittest.SkipTest("requires x64 mode")

     rng = jtu.rand_default(self.rng())
@@ -221,7 +221,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
   )
   def test_bicgstab_against_scipy(
       self, shape, dtype, preconditioner):
-    if not config.jax_enable_x64:
+    if not config.enable_x64.value:
       raise unittest.SkipTest("requires x64 mode")

     rng = jtu.rand_default(self.rng())
@@ -326,7 +326,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
   )
   def test_gmres_against_scipy(
       self, shape, dtype, preconditioner, solve_method):
-    if not config.x64_enabled:
+    if not config.enable_x64.value:
       raise unittest.SkipTest("requires x64 mode")

     rng = jtu.rand_default(self.rng())
@@ -437,7 +437,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     """
     The Arnoldi decomposition within GMRES is correct.
     """
-    if not config.x64_enabled:
+    if not config.enable_x64.value:
       raise unittest.SkipTest("requires x64 mode")

     rng = jtu.rand_default(self.rng())
@@ -24,6 +24,7 @@ from absl.testing import parameterized

 import scipy.stats

+from jax._src import config
 from jax._src import core
 from jax._src import test_util as jtu
 from jax._src import ad_checkpoint
@@ -33,8 +34,7 @@ from jax import random
 import jax
 import jax.numpy as jnp

-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()


 class NNFunctionsTest(jtu.JaxTestCase):
@@ -143,7 +143,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     # TODO(mattjj): include log_softmax in these extra tests if/when we add a
     # custom_jvp rule for it (since otherwise it doesn't pass the numerical
     # checks below).
-    if fn is nn.softmax and config.jax_softmax_custom_jvp:
+    if fn is nn.softmax and config.softmax_custom_jvp.value:
       g_fun = lambda x: jnp.take(fn(x, where=m, initial=-jnp.inf),
                                  jnp.array([0, 2, 3]))
       jtu.check_grads(g_fun, (x,), order=2)
@@ -153,7 +153,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3)

   def testSoftmaxGradResiduals(self):
-    if not jax.config.jax_softmax_custom_jvp:
+    if not config.softmax_custom_jvp.value:
       raise unittest.SkipTest("only applies when upgrade flag enabled")
     x = jnp.array([5.5, 1.3, -4.2, 0.9])
     res = ad_checkpoint.saved_residuals(nn.softmax, x)
@@ -30,6 +30,7 @@ import concurrent.futures
 import jax
 import jax.numpy as jnp
 from jax._src import core
+from jax._src import config
 from jax._src import test_util as jtu
 from jax import dtypes
 from jax import stages
@@ -59,8 +60,7 @@ from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
 from jax._src.util import curry, unzip2, safe_zip

-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()

 prev_xla_flags = None

@@ -955,7 +955,7 @@ class PJitTest(jtu.BufferDonationTestCase):

   @jtu.with_mesh([('x', 2)])
   def testWithCustomPRNGKey(self):
-    if not config.jax_enable_custom_prng:
+    if not config.enable_custom_prng.value:
       raise unittest.SkipTest("test requires jax_enable_custom_prng")
     key = prng.seed_with_impl(prng.rbg_prng_impl, 87)
     # Make sure this doesn't crash
@@ -22,8 +22,8 @@ import unittest
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
-from jax import config
 from jax import lax
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import test_util as jtu
@@ -40,7 +40,7 @@ import jax.numpy as jnp
 from jax.sharding import Mesh
 import numpy as np

-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()


 def _format_multiline(text):
@@ -197,7 +197,7 @@ class PythonCallbackTest(jtu.JaxTestCase):

   @with_pure_and_io_callbacks
   def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
-    if config.jax_enable_x64:
+    if config.enable_x64.value:
       raise unittest.SkipTest("Test only needed when 64-bit mode disabled.")

     @jax.jit
@@ -29,6 +29,7 @@ from jax import grad
 from jax import lax
 from jax import numpy as jnp
 from jax import random
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import test_util as jtu
@@ -36,8 +37,7 @@ from jax import vmap

 from jax._src import prng as prng_internal

-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()

 float_dtypes = jtu.dtypes.all_floating
 complex_dtypes = jtu.dtypes.complex
@@ -63,7 +63,7 @@ class LaxRandomTest(jtu.JaxTestCase):
     fail_prob = 0.003 if samples.dtype == jnp.bfloat16 else 0.01
     # TODO(frostig): This reads enable_custom_prng as a proxy for
     # whether RBG keys may be involved, but that's no longer exact.
-    if config.jax_enable_custom_prng and samples.dtype == jnp.bfloat16:
+    if config.enable_custom_prng.value and samples.dtype == jnp.bfloat16:
       return
     self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)

@@ -382,7 +382,7 @@ class LaxRandomTest(jtu.JaxTestCase):
       dtype=[np.float64],  # NOTE: KS test fails with float32
   )
   def testBeta(self, a, b, dtype):
-    if not config.x64_enabled:
+    if not config.enable_x64.value:
       raise SkipTest("skip test except on X64")
     key = self.make_key(0)
     rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
@@ -824,7 +824,7 @@ class LaxRandomTest(jtu.JaxTestCase):
     assert np.unique(keys, axis=0).shape[0] == 2

   def testStaticShapeErrors(self):
-    if config.jax_disable_jit:
+    if config.disable_jit.value:
       raise SkipTest("test only relevant when jit enabled")

     @jax.jit
@@ -991,7 +991,7 @@ class LaxRandomTest(jtu.JaxTestCase):
   def test_prng_jit_invariance(self, seed, type_):
     if type_ == "int" and seed == (1 << 64) - 1:
       self.skipTest("Expected failure: Python int too large.")
-    if not config.x64_enabled and seed > np.iinfo(np.int32).max:
+    if not config.enable_x64.value and seed > np.iinfo(np.int32).max:
       self.skipTest("Expected failure: Python int too large.")
     type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
     args_maker = lambda: [type_(seed)]
@@ -30,6 +30,7 @@ from jax import lax
 from jax import numpy as jnp
 from jax import random
 from jax import tree_util
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import test_util as jtu
@@ -39,8 +40,7 @@ from jax.interpreters import xla
 from jax._src import random as jax_random
 from jax._src import prng as prng_internal

-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()


 PRNG_IMPLS = list(prng_internal.prngs.items())
@@ -247,7 +247,7 @@ class PrngTest(jtu.JaxTestCase):
     finally:
       xla.apply_primitive = apply_primitive

-  @skipIf(config.jax_threefry_partitionable, 'changed random bit values')
+  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
   @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
   def testRngRandomBits(self, make_key):
     # Test specific outputs to ensure consistent random values between JAX versions.
@@ -280,7 +280,7 @@ class PrngTest(jtu.JaxTestCase):

     with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
       bits64 = random_bits(key, 64, (3,))
-    if config.x64_enabled:
+    if config.enable_x64.value:
       expected64 = np.array([3982329540505020460, 16822122385914693683,
                              7882654074788531506], dtype=np.uint64)
     else:
@@ -315,11 +315,11 @@ class PrngTest(jtu.JaxTestCase):

     with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
       bits64 = random_bits(key, 64, (3,))
-    expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32')
+    expected_dtype = np.dtype('uint64' if config.enable_x64.value else 'uint32')
     self.assertEqual(bits64.shape, (3,))
     self.assertEqual(bits64.dtype, expected_dtype)

-  @skipIf(config.jax_threefry_partitionable, 'changed random bit values')
+  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
   @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
   def testRngRandomBitsViewProperty(self, make_key):
     # TODO: add 64-bit if it ever supports this property.
@@ -338,7 +338,7 @@ class PrngTest(jtu.JaxTestCase):


   @jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS)
-  @skipIf(config.jax_threefry_partitionable, 'changed random bit values')
+  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
   @jtu.skip_on_devices("tpu")  # TPU precision causes issues.
   def testRandomDistributionValues(self, case, make_key):
     """
@@ -355,9 +355,9 @@ class PrngTest(jtu.JaxTestCase):
     * Considering adding a flag that reverts the new behavior, made
       available for a deprecation window's amount of time.
     """
-    if config.x64_enabled and case.on_x64 == OnX64.SKIP:
+    if config.enable_x64.value and case.on_x64 == OnX64.SKIP:
       self.skipTest("test produces different values when jax_enable_x64=True")
-    if not config.x64_enabled and case.on_x64 == OnX64.ONLY:
+    if not config.enable_x64.value and case.on_x64 == OnX64.ONLY:
       self.skipTest("test only valid when jax_enable_x64=True")
     with jax.default_prng_impl(case.prng_impl):
       func = getattr(random, case.name)
@@ -368,7 +368,7 @@ class PrngTest(jtu.JaxTestCase):
       actual = func(key, **case.params, shape=case.shape)
       self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)

-  @skipIf(config.jax_threefry_partitionable, 'changed random bit values')
+  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
   @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
   def testPRNGValues(self, make_key):
     # Test to ensure consistent random values between JAX versions
@@ -376,7 +376,7 @@ class PrngTest(jtu.JaxTestCase):

     self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
                      dtypes.canonicalize_dtype(jnp.int_))
-    if config.x64_enabled:
+    if config.enable_x64.value:
       self.assertAllClose(
           random.randint(k, (3, 3), 0, 8, dtype='int64'),
           np.array([[7, 2, 6],
@@ -407,7 +407,7 @@ class PrngTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, msg):
       random.bits(make_key(0), (3, 4), np.dtype('float16'))

-  @skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
+  @skipIf(not config.threefry_partitionable.value, 'enable after upgrade')
   @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
   def test_threefry_split_fold_in_symmetry(self, make_key):
     with jax.default_prng_impl('threefry2x32'):
@@ -420,7 +420,7 @@ class PrngTest(jtu.JaxTestCase):
     self.assertArraysEqual(f2, s2)
     self.assertArraysEqual(f3, s3)

-  @skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
+  @skipIf(not config.threefry_partitionable.value, 'enable after upgrade')
   @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
   def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
     # See https://github.com/google/jax/issues/7708
@@ -435,7 +435,7 @@ class PrngTest(jtu.JaxTestCase):
     self.assertArraysEqual(f2, s2)
     self.assertArraysEqual(f3, s3)

-  @skipIf(config.jax_threefry_partitionable, 'changed random bit values')
+  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
   def test_loggamma_nan_corner_case(self):
     # regression test for https://github.com/google/jax/issues/17922
     # This particular key previously led to NaN output.
@@ -459,20 +459,20 @@ class PrngTest(jtu.JaxTestCase):
      {"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
      {"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
      {"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
-     {"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
-     {"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
+     {"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]},
+     {"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]},
      {"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
      {"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
-     {"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
-     {"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
+     {"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]},
+     {"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]},
      {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
      {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
      {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
      {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
-     {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
-     {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
-     {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
-     {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
+     {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]},
+     {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]},
+     {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]},
+     {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]},
    ]
    for params in [dict(**d, make_key=ctor) for ctor in KEY_CTORS]
  ])
@@ -482,7 +482,7 @@ class PrngTest(jtu.JaxTestCase):
       maker = lambda k: random.key_data(jax.jit(make_key)(k))
     else:
       maker = lambda k: random.key_data(make_key(k))
-    if (jit and typ is int and not config.x64_enabled and
+    if (jit and typ is int and not config.enable_x64.value and
        (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
       # We expect an error to be raised.
       # NOTE: we check 'if jit' because some people rely on builtin int seeds
@@ -607,7 +607,7 @@ class KeyArrayTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, "Cannot interpret"):
       jnp.issubdtype(key, dtypes.prng_key)

-  @skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
+  @skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag')
   def test_construction_upgrade_flag(self):
     key = random.PRNGKey(42)
     self.assertIsInstance(key, jax_random.PRNGKeyArray)
@@ -1059,7 +1059,7 @@ class KeyArrayTest(jtu.JaxTestCase):
     self.assertArraysEqual(jax.random.key_data(key1),
                            jax.random.key_data(key2))

-    impl = config.jax_default_prng_impl
+    impl = config.default_prng_impl.value
     key3 = jax.random.wrap_key_data(data, impl=impl)
     self.assertEqual(key1.dtype, key3.dtype)
     self.assertArraysEqual(jax.random.key_data(key1),
@@ -28,9 +28,9 @@ import numpy as np

 import jax
 from jax import lax
-from jax import config
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
+from jax._src import config
 from jax._src import core
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
@@ -43,7 +43,7 @@ import jax.numpy as jnp

 from jax.experimental.shard_map import shard_map

-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()

 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
@@ -589,7 +589,7 @@ class ShardMapTest(jtu.JaxTestCase):
       return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16,
                                 dtype=jnp.int32)

-    pspec = P('x') if config.jax_enable_custom_prng else P('x', None)
+    pspec = P('x') if config.enable_custom_prng.value else P('x', None)
     g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec)
     _ = g(sharded_rng)  # don't crash!

@@ -24,8 +24,8 @@ import jax
 from jax import random
 from jax import lax
 from jax._src import core
+from jax._src import config
 from jax._src import linear_util as lu
-from jax import config
 from jax._src.interpreters import partial_eval as pe
 from jax._src import test_util as jtu
 from jax._src.util import tuple_insert
@@ -48,7 +48,7 @@ from jax._src.state.primitives import (get_p, swap_p, addupdate_p,
 from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
                                   AccumEffect, AbstractRef)

-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()

 class StatePrimitivesTest(jtu.JaxTestCase):

@@ -413,9 +413,9 @@ class StatePrimitivesTest(jtu.JaxTestCase):
   def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims,
                 idx_bdims, out_bdim, op):

-    float_ = (jnp.dtype('float64') if jax.config.jax_enable_x64 else
+    float_ = (jnp.dtype('float64') if config.enable_x64.value else
               jnp.dtype('float32'))
-    int_ = (jnp.dtype('int64') if jax.config.jax_enable_x64 else
+    int_ = (jnp.dtype('int64') if config.enable_x64.value else
             jnp.dtype('int32'))
     axis_size = 7
     out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b)
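A closing note on the test hunks (for example the `tf_test_util.py` change above): instead of `config.update("jax_serialization_version", version)` plus a manual `addCleanup` to restore the old value, a test now installs the override through the state object's context-manager form. A hedged sketch of that pattern with a plain `unittest` test case and an explicit `ExitStack` (the JAX test base class in the diff provides `enter_context` directly):

```python
import contextlib
import unittest

from jax._src import config

class SerializationVersionTest(unittest.TestCase):
  def setUp(self):
    super().setUp()
    stack = contextlib.ExitStack()
    self.addCleanup(stack.close)
    # Read the current value through the typed state object...
    version = config.jax_serialization_version.value
    # ...and install a scoped override that the stack undoes in teardown.
    stack.enter_context(config.jax_serialization_version(version))
```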