Make jax_enable_x64 a thread-local value.

This commit is contained in:
Jake VanderPlas 2021-02-04 09:48:22 -08:00
parent 10cff5f2bf
commit 2fd682ef2a
25 changed files with 126 additions and 81 deletions

View File

@ -22,7 +22,7 @@ from typing import Any, Callable
import numpy as np
import jax
from jax.config import FLAGS
from jax.config import config
partial = functools.partial
@ -195,7 +195,7 @@ def cache(max_size=4096):
if jax.core.debug_state.check_leaks:
return f(*args, **kwargs)
else:
return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)
return cached(bool(config.x64_enabled), *args, **kwargs)
wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
@ -209,7 +209,7 @@ def memoize(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
return memoized(bool(FLAGS.jax_enable_x64), *args, **kwargs)
return memoized(bool(config.x64_enabled), *args, **kwargs)
wrapper.cache_clear = memoized.cache_clear
wrapper.cache_info = memoized.cache_info

View File

@ -110,6 +110,7 @@ def _check_callable(fun):
raise TypeError(f"Expected a function, got a generator function: {fun}")
# TODO(jakevdp): merge this with _thread_local_state in jax.config
class _ThreadLocalState(threading.local):
def __init__(self):
@ -355,7 +356,7 @@ def _cpp_jit(
functions decorated with jax.jit), so we delay inspecting the value
of the jax_enable_x64 flag until JIT time.
"""
return FLAGS.jax_enable_x64
return config.x64_enabled
def get_jax_disable_jit_flag():
"""Returns the value of the `jax_disable_jit` flag.
@ -376,7 +377,7 @@ def _cpp_jit(
@api_boundary
def f_jitted(*args, **kwargs):
context = (getattr(core.thread_local_state.trace_state.trace_stack,
'dynamic', None), bool(FLAGS.jax_enable_x64))
'dynamic', None), config.x64_enabled)
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(context, *args, **kwargs)

View File

@ -14,6 +14,8 @@
import os
import sys
import threading
from typing import Optional
def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
@ -40,7 +42,15 @@ def int_env(varname: str, default: int) -> int:
return int(os.getenv(varname, default))
class _ThreadLocalState(threading.local):
def __init__(self):
self.enable_x64: Optional[bool] = None
class Config:
_thread_local_state = _ThreadLocalState()
def __init__(self):
self.values = {}
self.meta = {}
@ -121,7 +131,6 @@ class Config:
if not FLAGS.jax_omnistaging:
self.disable_omnistaging()
def register_omnistaging_disabler(self, disabler):
if self.omnistaging_enabled:
self._omnistaging_disablers.append(disabler)
@ -138,6 +147,16 @@ class Config:
disabler()
self.omnistaging_enabled = False
@property
def x64_enabled(self):
if self._thread_local_state.enable_x64 is None:
self._thread_local_state.enable_x64 = bool(self.read('jax_enable_x64'))
return self._thread_local_state.enable_x64
# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
def _set_x64_enabled(self, state):
self._thread_local_state.enable_x64 = bool(state)
class NameSpace(object):
def __init__(self, getter):

View File

@ -28,7 +28,7 @@ from typing import Dict
import numpy as np
from ._src import util
from .config import flags
from .config import flags, config
from .lib import xla_client
from ._src import traceback_util
@ -67,7 +67,7 @@ _dtype_to_32bit_dtype = {
@util.memoize
def canonicalize_dtype(dtype):
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
if isinstance(dtype, str) and dtype == "bfloat16":
dtype = bfloat16
try:
@ -75,7 +75,7 @@ def canonicalize_dtype(dtype):
except TypeError as e:
raise TypeError(f'dtype {dtype!r} not understood') from e
if FLAGS.jax_enable_x64:
if config.x64_enabled:
return dtype
else:
return _dtype_to_32bit_dtype.get(dtype, dtype)

View File

@ -33,7 +33,6 @@ from jax.config import config
import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
# Import after parsing flags
from jax.experimental.jax2tf.tests import primitive_harness
@ -133,7 +132,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
raise unittest.SkipTest("Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation")
# The CPU/GPU have more supported types than TPU.
self.assertEqual("cpu", jtu.device_under_test(), "The documentation can be generated only on CPU")
self.assertTrue(FLAGS.jax_enable_x64, "The documentation must be generated with JAX_ENABLE_X64=1")
self.assertTrue(config.x64_enabled, "The documentation must be generated with JAX_ENABLE_X64=1")
with open(os.path.join(os.path.dirname(__file__),
'../g3doc/jax_primitives_coverage.md.template')) as f:

View File

@ -70,7 +70,6 @@ from jax.interpreters import xla
import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
# Import after parsing flags
from jax.experimental.jax2tf.tests import tf_test_util
@ -206,7 +205,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
# The CPU has more supported types, and harnesses
self.assertEqual("cpu", jtu.device_under_test())
self.assertTrue(
FLAGS.jax_enable_x64,
config.x64_enabled,
"Documentation generation must be run with JAX_ENABLE_X64=1")
with open(

View File

@ -51,7 +51,7 @@ class StaxTest(tf_test_util.JaxToTfTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_res_net(self):
if config.FLAGS.jax_enable_x64:
if config.x64_enabled:
raise unittest.SkipTest("ResNet test fails on JAX when X64 is enabled")
key = jax.random.PRNGKey(0)
shape = (224, 224, 3, 1)

View File

@ -99,7 +99,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
def to_numpy_dtype(dt):
return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype
if not config.FLAGS.jax_enable_x64 and canonicalize_dtypes:
if not config.x64_enabled and canonicalize_dtypes:
self.assertEqual(
dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))

View File

@ -18,7 +18,7 @@
"""
from contextlib import contextmanager
from jax import config
from jax.config import config
@contextmanager
def enable_x64():
@ -36,12 +36,12 @@ def enable_x64():
--------
jax.experimental.disable_x64 : temporarily disable X64 mode.
"""
_x64_state = config.FLAGS.jax_enable_x64
config.update('jax_enable_x64', True)
_x64_state = config.x64_enabled
config._set_x64_enabled(True)
try:
yield
finally:
config.update('jax_enable_x64', _x64_state)
config._set_x64_enabled(_x64_state)
@contextmanager
def disable_x64():
@ -59,9 +59,9 @@ def disable_x64():
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
_x64_state = config.FLAGS.jax_enable_x64
config.update('jax_enable_x64', False)
_x64_state = config.x64_enabled
config._set_x64_enabled(False)
try:
yield
finally:
config.update('jax_enable_x64', _x64_state)
config._set_x64_enabled(_x64_state)

View File

@ -282,7 +282,7 @@ def host_count(backend: Optional[str] = None) -> int:
@util.memoize
def dtype_to_etype(dtype):
"""Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64)."""
"""Convert from dtype to canonical etype (reading config.x64_enabled)."""
return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))

View File

@ -73,7 +73,7 @@ from .tree_util import tree_map
from ._src import traceback_util
from .config import FLAGS
from .config import config
traceback_util.register_exclusion(__file__)
@ -249,9 +249,9 @@ def cache(call: Callable):
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if core.debug_state.check_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args, bool(FLAGS.jax_enable_x64))
key = (_copy_main_traces(fun.transforms), fun.params, args, config.x64_enabled)
else:
key = (fun.transforms, fun.params, args, bool(FLAGS.jax_enable_x64))
key = (fun.transforms, fun.params, args, config.x64_enabled)
result = cache.get(key, None)
if result is not None:
ans, stores = result

View File

@ -32,7 +32,7 @@ from . import api
from . import core
from . import dtypes as _dtypes
from . import lax
from .config import flags, bool_env
from .config import flags, bool_env, config
from ._src.util import partial, prod
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
@ -172,7 +172,7 @@ def check_close(xs, ys, atol=None, rtol=None):
def _check_dtypes_match(xs, ys):
def _assert_dtypes_match(x, y):
if FLAGS.jax_enable_x64:
if config.x64_enabled:
assert _dtype(x) == _dtype(y)
else:
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
@ -373,7 +373,7 @@ def supported_dtypes():
np.uint8, np.uint16, np.uint32, np.uint64,
_dtypes.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128}
if not FLAGS.jax_enable_x64:
if not config.x64_enabled:
types -= {np.uint64, np.int64, np.float64, np.complex128}
return types
@ -817,7 +817,7 @@ class JaxTestCase(parameterized.TestCase):
self.assertDtypesMatch(x, y)
def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
if not FLAGS.jax_enable_x64 and canonicalize_dtypes:
if not config.x64_enabled and canonicalize_dtypes:
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)),
_dtypes.canonicalize_dtype(_dtype(y)))
else:

View File

@ -193,7 +193,7 @@ class CPPJitTest(jtu.JaxTestCase):
assert len(side) == 2 # but should still cache
f(one, two, z=np.zeros(3)) # doesn't crash
if FLAGS.jax_enable_x64:
if config.x64_enabled:
# In the above call, three is of a new type (int64), thus it should
# trigger a new compilation.
assert len(side) == 3
@ -1582,7 +1582,7 @@ class APITest(jtu.JaxTestCase):
def test_dtype_warning(self):
# cf. issue #1230
if FLAGS.jax_enable_x64:
if config.x64_enabled:
raise unittest.SkipTest("test only applies when x64 is disabled")
def check_warning(warn, nowarn):

View File

@ -25,7 +25,6 @@ from jax import test_util as jtu
import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
try:
import torch
@ -95,7 +94,7 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in dlpack_dtypes))
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testTensorFlowToJax(self, shape, dtype):
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
jnp.float64]:
raise self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.device_under_test() == "gpu" and
@ -118,7 +117,7 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in dlpack_dtypes))
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testJaxToTensorFlow(self, shape, dtype):
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
jnp.float64]:
self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.device_under_test() == "gpu" and
@ -142,7 +141,7 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in torch_dtypes))
@unittest.skipIf(not torch, "Test requires PyTorch")
def testTorchToJax(self, shape, dtype):
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
@ -160,7 +159,7 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in torch_dtypes))
@unittest.skipIf(not torch, "Test requires PyTorch")
def testJaxToTorch(self, shape, dtype):
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)

View File

@ -22,11 +22,9 @@ from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental.doubledouble import doubledouble, _DoubleDouble
from jax.config import config, flags
from jax.config import config
config.parse_flags_with_absl()
FLAGS = flags.FLAGS
class DoubleDoubleTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(
@ -73,7 +71,7 @@ class DoubleDoubleTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
double_op1 = doubledouble(op1)
args = 1E20 * rng(shape, dtype), rng(shape, dtype)
check_dtypes = not FLAGS.jax_enable_x64
check_dtypes = not config.x64_enabled
self.assertAllClose(double_op1(*args), op2(*args), check_dtypes=check_dtypes)

View File

@ -30,7 +30,6 @@ from jax.interpreters import xla
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
bool_dtypes = [np.dtype('bool')]
@ -73,7 +72,7 @@ class DtypesTest(jtu.JaxTestCase):
True: _EXPECTED_CANONICALIZE_X64,
False: _EXPECTED_CANONICALIZE_X32,
}
for in_dtype, expected_dtype in expected[FLAGS.jax_enable_x64].items():
for in_dtype, expected_dtype in expected[config.x64_enabled].items():
self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype)
@parameterized.named_parameters(
@ -229,7 +228,7 @@ class TestPromotionTables(jtu.JaxTestCase):
# Note: * here refers to weakly-typed values
typecodes = \
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*']
if FLAGS.jax_enable_x64:
if config.x64_enabled:
expected = [
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i8','f8','c8'],
['u1','u1','u2','u4','u8','i2','i2','i4','i8','bf','f2','f4','f8','c4','c8','u1','f8','c8'],

View File

@ -25,10 +25,8 @@ from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
from jax.config import flags
config.parse_flags_with_absl()
FLAGS = flags.FLAGS
float_dtypes = jtu.dtypes.floating
inexact_dtypes = jtu.dtypes.inexact
@ -112,7 +110,7 @@ class FftTest(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker)
# Test gradient for differentiable types.
if (FLAGS.jax_enable_x64 and
if (config.x64_enabled and
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
# TODO(skye): can we be more precise?
tol = 0.15

View File

@ -22,11 +22,10 @@ from jax import dtypes
from jax import lib as jaxlib
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import flags
from jax.config import config
from jax.lib import version
import numpy as np
FLAGS = flags.FLAGS
# It covers all JAX numpy types types except bfloat16 and numpy array.
# TODO(jblespiau): Add support for float0 and bfloat16 in the C++ path.
@ -37,7 +36,7 @@ _SCALAR_NUMPY_TYPES = [
def _cpp_device_put(value, device):
return jaxlib.jax_jit.device_put(value, FLAGS.jax_enable_x64, device)
return jaxlib.jax_jit.device_put(value, config.x64_enabled, device)
class JaxJitTest(parameterized.TestCase):
@ -167,7 +166,7 @@ class JaxJitTest(parameterized.TestCase):
"old jaxlib version")
def test_arg_signature_of_value(self):
"""Tests the C++ code-path."""
jax_enable_x64 = FLAGS.jax_enable_x64
jax_enable_x64 = config.x64_enabled
# 1. Numpy scalar types
for dtype in _SCALAR_NUMPY_TYPES:

View File

@ -20,7 +20,6 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class JaxprStatsTest(jtu.JaxTestCase):
@ -59,7 +58,7 @@ class JaxprStatsTest(jtu.JaxTestCase):
hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)
t = '64' if FLAGS.jax_enable_x64 else '32'
t = '64' if config.x64_enabled else '32'
shapes = [
f'add :: float{t}[]',
f'sin :: float{t}[]',

View File

@ -34,7 +34,6 @@ from jax._src import util
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
# We disable the whitespace continuation check in this file because otherwise it
# makes the test name formatting unwieldy.
@ -54,7 +53,7 @@ IndexSpec = collections.namedtuple("IndexTest", ["shape", "indexer"])
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
# TODO(mattjj,dougalm): add higher-order check
default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
default_tol = 1e-6 if config.x64_enabled else 1e-2
atol = atol or default_tol
rtol = rtol or default_tol
eps = eps or default_tol

View File

@ -660,7 +660,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for rec in JAX_BITWISE_OP_RECORDS))
def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes):
rng = rng_factory(self.rng())
if not FLAGS.jax_enable_x64 and any(
if not config.x64_enabled and any(
jnp.iinfo(dtype).bits == 64 for dtype in dtypes):
self.skipTest("x64 types are disabled by jax_enable_x64")
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
@ -684,7 +684,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np.issubdtype(shift_dtype, np.signedinteger)
has_32 = any(np.iinfo(d).bits == 32 for d in dtypes)
promoting_to_64 = has_32 and signed_mix
if promoting_to_64 and not FLAGS.jax_enable_x64:
if promoting_to_64 and not config.x64_enabled:
self.skipTest("np.right_shift/left_shift promoting to int64"
"differs from jnp in 32 bit mode.")
@ -2317,7 +2317,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory):
# integer types are converted to float64 in numpy's implementation
if (x1_dtype not in [jnp.bfloat16, np.float16, np.float32]
and not FLAGS.jax_enable_x64):
and not config.x64_enabled):
self.skipTest("Only run float64 testcase when float64 is enabled.")
x1_rng = x1_rng_factory(self.rng())
x2_rng = x2_rng_factory(self.rng())
@ -2344,7 +2344,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testFrexp(self, shape, dtype, rng_factory):
# integer types are converted to float64 in numpy's implementation
if (dtype not in [jnp.bfloat16, np.float16, np.float32]
and not FLAGS.jax_enable_x64):
and not config.x64_enabled):
self.skipTest("Only run float64 testcase when float64 is enabled.")
rng = rng_factory(self.rng())
np_fun = lambda x: np.frexp(x)
@ -3326,7 +3326,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if jtu.device_under_test() == 'tpu':
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
if not FLAGS.jax_enable_x64:
if not config.x64_enabled:
if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_fullrange(self.rng())
@ -4391,7 +4391,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
endpoint, base, dtype):
if (dtype in int_dtypes and
jtu.device_under_test() in ("gpu", "tpu") and
not FLAGS.jax_enable_x64):
not config.x64_enabled):
raise unittest.SkipTest("GPUx32 truncated exponentiation"
" doesn't exactly match other platforms.")
rng = jtu.rand_default(self.rng())

View File

@ -90,7 +90,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
for dtype in [np.float64, np.complex128]
for preconditioner in [None, 'identity', 'exact', 'random']))
def test_cg_against_scipy(self, shape, dtype, preconditioner):
if not config.FLAGS.jax_enable_x64:
if not config.x64_enabled:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())
@ -208,7 +208,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
for solve_method in ['incremental', 'batched']))
def test_gmres_against_scipy(
self, shape, dtype, preconditioner, solve_method):
if not config.FLAGS.jax_enable_x64:
if not config.x64_enabled:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())
@ -325,7 +325,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
"""
The Arnoldi decomposition within GMRES is correct.
"""
if not config.FLAGS.jax_enable_x64:
if not config.x64_enabled:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())

View File

@ -41,7 +41,6 @@ from jax._src.lax.lax import _device_put_raw
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
### lax tests
@ -194,7 +193,7 @@ class LaxTest(jtu.JaxTestCase):
for dtype in rec.dtypes)
for rec in LAX_OPS))
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
if (not FLAGS.jax_enable_x64 and op_name == "nextafter"
if (not config.x64_enabled and op_name == "nextafter"
and dtype == np.float64):
raise SkipTest("64-bit mode disabled")
rng = rng_factory(self.rng())
@ -897,7 +896,7 @@ class LaxTest(jtu.JaxTestCase):
(np.int8, np.int64), (np.int16, np.int16), (np.int16, np.int32), (np.int16, np.int64),
(np.int32, np.int32), (np.int32, np.int64), (np.int64, np.int64)]))
def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype, preferred_element_type):
if (not FLAGS.jax_enable_x64 and
if (not config.x64_enabled and
(dtype == np.float64 or preferred_element_type == np.float64
or dtype == np.int64 or preferred_element_type == np.int64)):
raise SkipTest("64-bit mode disabled")

View File

@ -37,7 +37,6 @@ import jax._src.random
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
@ -156,7 +155,7 @@ class LaxRandomTest(jtu.JaxTestCase):
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
bits64 = jax._src.random._random_bits(key, 64, (3,))
if FLAGS.jax_enable_x64:
if config.x64_enabled:
expected64 = np.array([3982329540505020460, 16822122385914693683,
7882654074788531506], dtype=np.uint64)
else:
@ -397,7 +396,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for b in [0.2, 5.]
for dtype in [np.float64])) # NOTE: KS test fails with float32
def testBeta(self, a, b, dtype):
if not FLAGS.jax_enable_x64:
if not config.x64_enabled:
raise SkipTest("skip test except on X64")
key = random.PRNGKey(0)
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
@ -728,7 +727,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testIssue756(self):
key = random.PRNGKey(0)
w = random.normal(key, ())
if FLAGS.jax_enable_x64:
if config.x64_enabled:
self.assertEqual(np.result_type(w), np.float64)
else:
self.assertEqual(np.result_type(w), np.float32)
@ -751,7 +750,7 @@ class LaxRandomTest(jtu.JaxTestCase):
# Test to ensure consistent random values between JAX versions
k = random.PRNGKey(0)
if FLAGS.jax_enable_x64:
if config.x64_enabled:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8),
np.array([[7, 2, 6],
@ -917,20 +916,20 @@ class LaxRandomTest(jtu.JaxTestCase):
{"seed": 2, "type": np.uint32, "jit": False, "key": [0, 2]},
{"seed": 3, "type": np.int64, "jit": True, "key": [0, 3]},
{"seed": 3, "type": np.int64, "jit": False, "key": [0, 3]},
{"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if FLAGS.jax_enable_x64 else [0, 4294967295]},
{"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if FLAGS.jax_enable_x64 else [0, 4294967295]},
{"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -2, "type": np.int32, "jit": True, "key": [0, 4294967294]},
{"seed": -2, "type": np.int32, "jit": False, "key": [0, 4294967294]},
{"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if FLAGS.jax_enable_x64 else [0, 4294967293]},
{"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if FLAGS.jax_enable_x64 else [0, 4294967293]},
{"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": True, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": False, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": True, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": False, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if FLAGS.jax_enable_x64 else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if FLAGS.jax_enable_x64 else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if FLAGS.jax_enable_x64 else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if FLAGS.jax_enable_x64 else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
]
))
def test_prng_seeds_and_keys(self, seed, type, jit, key):

View File

@ -12,13 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import time
from absl.testing import parameterized
from jax import api, config, lax, partial, random
from jax import api, lax, partial, random
from jax.config import config, FLAGS
from jax.experimental import enable_x64, disable_x64
import jax.numpy as jnp
import jax.test_util as jtu
config.parse_flags_with_absl()
def _maybe_jit(jit_type, func, *args, **kwargs):
if jit_type == "python":
return api._python_jit(func, *args, **kwargs)
@ -84,3 +92,33 @@ class X64ContextTests(jtu.JaxTestCase):
with disable_x64():
self.assertArraysEqual(count_to(10), jnp.float32(10), check_dtypes=True)
def test_thread_safety(self):
def func_x32():
with disable_x64():
time.sleep(0.1)
return jnp.arange(10).dtype
def func_x64():
with enable_x64():
time.sleep(0.1)
return jnp.arange(10).dtype
with concurrent.futures.ThreadPoolExecutor() as executor:
x32 = executor.submit(func_x32)
x64 = executor.submit(func_x64)
self.assertEqual(x64.result(), jnp.int64)
self.assertEqual(x32.result(), jnp.int32)
def test_jit_cache(self):
# TODO(jakevdp): enable this test when CPP jit cache is fixed.
if FLAGS.experimental_cpp_jit:
self.skipTest("Known failure due to https://github.com/google/jax/issues/5532")
f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
with disable_x64():
for _ in range(2):
f()
with enable_x64():
for _ in range(2):
f()