mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make jax_enable_x64 a thread-local value.
This commit is contained in:
parent
10cff5f2bf
commit
2fd682ef2a
@ -22,7 +22,7 @@ from typing import Any, Callable
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.config import FLAGS
|
||||
from jax.config import config
|
||||
|
||||
partial = functools.partial
|
||||
|
||||
@ -195,7 +195,7 @@ def cache(max_size=4096):
|
||||
if jax.core.debug_state.check_leaks:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)
|
||||
return cached(bool(config.x64_enabled), *args, **kwargs)
|
||||
|
||||
wrapper.cache_clear = cached.cache_clear
|
||||
wrapper.cache_info = cached.cache_info
|
||||
@ -209,7 +209,7 @@ def memoize(f):
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
return memoized(bool(FLAGS.jax_enable_x64), *args, **kwargs)
|
||||
return memoized(bool(config.x64_enabled), *args, **kwargs)
|
||||
|
||||
wrapper.cache_clear = memoized.cache_clear
|
||||
wrapper.cache_info = memoized.cache_info
|
||||
|
@ -110,6 +110,7 @@ def _check_callable(fun):
|
||||
raise TypeError(f"Expected a function, got a generator function: {fun}")
|
||||
|
||||
|
||||
# TODO(jakevdp): merge this with _thread_local_state in jax.config
|
||||
class _ThreadLocalState(threading.local):
|
||||
|
||||
def __init__(self):
|
||||
@ -355,7 +356,7 @@ def _cpp_jit(
|
||||
functions decorated with jax.jit), so we delay inspecting the value
|
||||
of the jax_enable_x64 flag until JIT time.
|
||||
"""
|
||||
return FLAGS.jax_enable_x64
|
||||
return config.x64_enabled
|
||||
|
||||
def get_jax_disable_jit_flag():
|
||||
"""Returns the value of the `jax_disable_jit` flag.
|
||||
@ -376,7 +377,7 @@ def _cpp_jit(
|
||||
@api_boundary
|
||||
def f_jitted(*args, **kwargs):
|
||||
context = (getattr(core.thread_local_state.trace_state.trace_stack,
|
||||
'dynamic', None), bool(FLAGS.jax_enable_x64))
|
||||
'dynamic', None), config.x64_enabled)
|
||||
# TODO(jblespiau): Move this to C++.
|
||||
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
|
||||
device_arrays = cpp_jitted_f(context, *args, **kwargs)
|
||||
|
@ -14,6 +14,8 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
def bool_env(varname: str, default: bool) -> bool:
|
||||
"""Read an environment variable and interpret it as a boolean.
|
||||
@ -40,7 +42,15 @@ def int_env(varname: str, default: int) -> int:
|
||||
return int(os.getenv(varname, default))
|
||||
|
||||
|
||||
class _ThreadLocalState(threading.local):
|
||||
|
||||
def __init__(self):
|
||||
self.enable_x64: Optional[bool] = None
|
||||
|
||||
|
||||
class Config:
|
||||
_thread_local_state = _ThreadLocalState()
|
||||
|
||||
def __init__(self):
|
||||
self.values = {}
|
||||
self.meta = {}
|
||||
@ -121,7 +131,6 @@ class Config:
|
||||
if not FLAGS.jax_omnistaging:
|
||||
self.disable_omnistaging()
|
||||
|
||||
|
||||
def register_omnistaging_disabler(self, disabler):
|
||||
if self.omnistaging_enabled:
|
||||
self._omnistaging_disablers.append(disabler)
|
||||
@ -138,6 +147,16 @@ class Config:
|
||||
disabler()
|
||||
self.omnistaging_enabled = False
|
||||
|
||||
@property
|
||||
def x64_enabled(self):
|
||||
if self._thread_local_state.enable_x64 is None:
|
||||
self._thread_local_state.enable_x64 = bool(self.read('jax_enable_x64'))
|
||||
return self._thread_local_state.enable_x64
|
||||
|
||||
# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
|
||||
def _set_x64_enabled(self, state):
|
||||
self._thread_local_state.enable_x64 = bool(state)
|
||||
|
||||
|
||||
class NameSpace(object):
|
||||
def __init__(self, getter):
|
||||
|
@ -28,7 +28,7 @@ from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
from ._src import util
|
||||
from .config import flags
|
||||
from .config import flags, config
|
||||
from .lib import xla_client
|
||||
|
||||
from ._src import traceback_util
|
||||
@ -67,7 +67,7 @@ _dtype_to_32bit_dtype = {
|
||||
|
||||
@util.memoize
|
||||
def canonicalize_dtype(dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
|
||||
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
|
||||
if isinstance(dtype, str) and dtype == "bfloat16":
|
||||
dtype = bfloat16
|
||||
try:
|
||||
@ -75,7 +75,7 @@ def canonicalize_dtype(dtype):
|
||||
except TypeError as e:
|
||||
raise TypeError(f'dtype {dtype!r} not understood') from e
|
||||
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
return dtype
|
||||
else:
|
||||
return _dtype_to_32bit_dtype.get(dtype, dtype)
|
||||
|
@ -33,7 +33,6 @@ from jax.config import config
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
# Import after parsing flags
|
||||
from jax.experimental.jax2tf.tests import primitive_harness
|
||||
@ -133,7 +132,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation")
|
||||
# The CPU/GPU have more supported types than TPU.
|
||||
self.assertEqual("cpu", jtu.device_under_test(), "The documentation can be generated only on CPU")
|
||||
self.assertTrue(FLAGS.jax_enable_x64, "The documentation must be generated with JAX_ENABLE_X64=1")
|
||||
self.assertTrue(config.x64_enabled, "The documentation must be generated with JAX_ENABLE_X64=1")
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__),
|
||||
'../g3doc/jax_primitives_coverage.md.template')) as f:
|
||||
|
@ -70,7 +70,6 @@ from jax.interpreters import xla
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
# Import after parsing flags
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
@ -206,7 +205,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
# The CPU has more supported types, and harnesses
|
||||
self.assertEqual("cpu", jtu.device_under_test())
|
||||
self.assertTrue(
|
||||
FLAGS.jax_enable_x64,
|
||||
config.x64_enabled,
|
||||
"Documentation generation must be run with JAX_ENABLE_X64=1")
|
||||
|
||||
with open(
|
||||
|
@ -51,7 +51,7 @@ class StaxTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def test_res_net(self):
|
||||
if config.FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
raise unittest.SkipTest("ResNet test fails on JAX when X64 is enabled")
|
||||
key = jax.random.PRNGKey(0)
|
||||
shape = (224, 224, 3, 1)
|
||||
|
@ -99,7 +99,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
def to_numpy_dtype(dt):
|
||||
return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype
|
||||
|
||||
if not config.FLAGS.jax_enable_x64 and canonicalize_dtypes:
|
||||
if not config.x64_enabled and canonicalize_dtypes:
|
||||
self.assertEqual(
|
||||
dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
|
||||
dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
|
||||
|
@ -18,7 +18,7 @@
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from jax import config
|
||||
from jax.config import config
|
||||
|
||||
@contextmanager
|
||||
def enable_x64():
|
||||
@ -36,12 +36,12 @@ def enable_x64():
|
||||
--------
|
||||
jax.experimental.disable_x64 : temporarily disable X64 mode.
|
||||
"""
|
||||
_x64_state = config.FLAGS.jax_enable_x64
|
||||
config.update('jax_enable_x64', True)
|
||||
_x64_state = config.x64_enabled
|
||||
config._set_x64_enabled(True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
config.update('jax_enable_x64', _x64_state)
|
||||
config._set_x64_enabled(_x64_state)
|
||||
|
||||
@contextmanager
|
||||
def disable_x64():
|
||||
@ -59,9 +59,9 @@ def disable_x64():
|
||||
--------
|
||||
jax.experimental.enable_x64 : temporarily enable X64 mode.
|
||||
"""
|
||||
_x64_state = config.FLAGS.jax_enable_x64
|
||||
config.update('jax_enable_x64', False)
|
||||
_x64_state = config.x64_enabled
|
||||
config._set_x64_enabled(False)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
config.update('jax_enable_x64', _x64_state)
|
||||
config._set_x64_enabled(_x64_state)
|
||||
|
@ -282,7 +282,7 @@ def host_count(backend: Optional[str] = None) -> int:
|
||||
|
||||
@util.memoize
|
||||
def dtype_to_etype(dtype):
|
||||
"""Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64)."""
|
||||
"""Convert from dtype to canonical etype (reading config.x64_enabled)."""
|
||||
return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
|
||||
|
@ -73,7 +73,7 @@ from .tree_util import tree_map
|
||||
|
||||
from ._src import traceback_util
|
||||
|
||||
from .config import FLAGS
|
||||
from .config import config
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
@ -249,9 +249,9 @@ def cache(call: Callable):
|
||||
def memoized_fun(fun: WrappedFun, *args):
|
||||
cache = fun_caches.setdefault(fun.f, {})
|
||||
if core.debug_state.check_leaks:
|
||||
key = (_copy_main_traces(fun.transforms), fun.params, args, bool(FLAGS.jax_enable_x64))
|
||||
key = (_copy_main_traces(fun.transforms), fun.params, args, config.x64_enabled)
|
||||
else:
|
||||
key = (fun.transforms, fun.params, args, bool(FLAGS.jax_enable_x64))
|
||||
key = (fun.transforms, fun.params, args, config.x64_enabled)
|
||||
result = cache.get(key, None)
|
||||
if result is not None:
|
||||
ans, stores = result
|
||||
|
@ -32,7 +32,7 @@ from . import api
|
||||
from . import core
|
||||
from . import dtypes as _dtypes
|
||||
from . import lax
|
||||
from .config import flags, bool_env
|
||||
from .config import flags, bool_env, config
|
||||
from ._src.util import partial, prod
|
||||
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
||||
from .lib import xla_bridge
|
||||
@ -172,7 +172,7 @@ def check_close(xs, ys, atol=None, rtol=None):
|
||||
|
||||
def _check_dtypes_match(xs, ys):
|
||||
def _assert_dtypes_match(x, y):
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
assert _dtype(x) == _dtype(y)
|
||||
else:
|
||||
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
|
||||
@ -373,7 +373,7 @@ def supported_dtypes():
|
||||
np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
_dtypes.bfloat16, np.float16, np.float32, np.float64,
|
||||
np.complex64, np.complex128}
|
||||
if not FLAGS.jax_enable_x64:
|
||||
if not config.x64_enabled:
|
||||
types -= {np.uint64, np.int64, np.float64, np.complex128}
|
||||
return types
|
||||
|
||||
@ -817,7 +817,7 @@ class JaxTestCase(parameterized.TestCase):
|
||||
self.assertDtypesMatch(x, y)
|
||||
|
||||
def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
|
||||
if not FLAGS.jax_enable_x64 and canonicalize_dtypes:
|
||||
if not config.x64_enabled and canonicalize_dtypes:
|
||||
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)),
|
||||
_dtypes.canonicalize_dtype(_dtype(y)))
|
||||
else:
|
||||
|
@ -193,7 +193,7 @@ class CPPJitTest(jtu.JaxTestCase):
|
||||
assert len(side) == 2 # but should still cache
|
||||
|
||||
f(one, two, z=np.zeros(3)) # doesn't crash
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
# In the above call, three is of a new type (int64), thus it should
|
||||
# trigger a new compilation.
|
||||
assert len(side) == 3
|
||||
@ -1582,7 +1582,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_dtype_warning(self):
|
||||
# cf. issue #1230
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
raise unittest.SkipTest("test only applies when x64 is disabled")
|
||||
|
||||
def check_warning(warn, nowarn):
|
||||
|
@ -25,7 +25,6 @@ from jax import test_util as jtu
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
try:
|
||||
import torch
|
||||
@ -95,7 +94,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in dlpack_dtypes))
|
||||
@unittest.skipIf(not tf, "Test requires TensorFlow")
|
||||
def testTensorFlowToJax(self, shape, dtype):
|
||||
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
|
||||
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
|
||||
jnp.float64]:
|
||||
raise self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
if (jtu.device_under_test() == "gpu" and
|
||||
@ -118,7 +117,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in dlpack_dtypes))
|
||||
@unittest.skipIf(not tf, "Test requires TensorFlow")
|
||||
def testJaxToTensorFlow(self, shape, dtype):
|
||||
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
|
||||
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
|
||||
jnp.float64]:
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
if (jtu.device_under_test() == "gpu" and
|
||||
@ -142,7 +141,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in torch_dtypes))
|
||||
@unittest.skipIf(not torch, "Test requires PyTorch")
|
||||
def testTorchToJax(self, shape, dtype):
|
||||
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
|
||||
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
@ -160,7 +159,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in torch_dtypes))
|
||||
@unittest.skipIf(not torch, "Test requires PyTorch")
|
||||
def testJaxToTorch(self, shape, dtype):
|
||||
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
|
||||
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
|
@ -22,11 +22,9 @@ from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.experimental.doubledouble import doubledouble, _DoubleDouble
|
||||
|
||||
from jax.config import config, flags
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
class DoubleDoubleTest(jtu.JaxTestCase):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_{}".format(
|
||||
@ -73,7 +71,7 @@ class DoubleDoubleTest(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
double_op1 = doubledouble(op1)
|
||||
args = 1E20 * rng(shape, dtype), rng(shape, dtype)
|
||||
check_dtypes = not FLAGS.jax_enable_x64
|
||||
check_dtypes = not config.x64_enabled
|
||||
|
||||
self.assertAllClose(double_op1(*args), op2(*args), check_dtypes=check_dtypes)
|
||||
|
||||
|
@ -30,7 +30,6 @@ from jax.interpreters import xla
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
bool_dtypes = [np.dtype('bool')]
|
||||
|
||||
@ -73,7 +72,7 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
True: _EXPECTED_CANONICALIZE_X64,
|
||||
False: _EXPECTED_CANONICALIZE_X32,
|
||||
}
|
||||
for in_dtype, expected_dtype in expected[FLAGS.jax_enable_x64].items():
|
||||
for in_dtype, expected_dtype in expected[config.x64_enabled].items():
|
||||
self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -229,7 +228,7 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
# Note: * here refers to weakly-typed values
|
||||
typecodes = \
|
||||
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*']
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
expected = [
|
||||
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i8','f8','c8'],
|
||||
['u1','u1','u2','u4','u8','i2','i2','i4','i8','bf','f2','f4','f8','c4','c8','u1','f8','c8'],
|
||||
|
@ -25,10 +25,8 @@ from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax.config import flags
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
float_dtypes = jtu.dtypes.floating
|
||||
inexact_dtypes = jtu.dtypes.inexact
|
||||
@ -112,7 +110,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
# Test gradient for differentiable types.
|
||||
if (FLAGS.jax_enable_x64 and
|
||||
if (config.x64_enabled and
|
||||
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
|
||||
# TODO(skye): can we be more precise?
|
||||
tol = 0.15
|
||||
|
@ -22,11 +22,10 @@ from jax import dtypes
|
||||
from jax import lib as jaxlib
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.config import flags
|
||||
from jax.config import config
|
||||
from jax.lib import version
|
||||
import numpy as np
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# It covers all JAX numpy types types except bfloat16 and numpy array.
|
||||
# TODO(jblespiau): Add support for float0 and bfloat16 in the C++ path.
|
||||
@ -37,7 +36,7 @@ _SCALAR_NUMPY_TYPES = [
|
||||
|
||||
|
||||
def _cpp_device_put(value, device):
|
||||
return jaxlib.jax_jit.device_put(value, FLAGS.jax_enable_x64, device)
|
||||
return jaxlib.jax_jit.device_put(value, config.x64_enabled, device)
|
||||
|
||||
|
||||
class JaxJitTest(parameterized.TestCase):
|
||||
@ -167,7 +166,7 @@ class JaxJitTest(parameterized.TestCase):
|
||||
"old jaxlib version")
|
||||
def test_arg_signature_of_value(self):
|
||||
"""Tests the C++ code-path."""
|
||||
jax_enable_x64 = FLAGS.jax_enable_x64
|
||||
jax_enable_x64 = config.x64_enabled
|
||||
|
||||
# 1. Numpy scalar types
|
||||
for dtype in _SCALAR_NUMPY_TYPES:
|
||||
|
@ -20,7 +20,6 @@ from jax.config import config
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
class JaxprStatsTest(jtu.JaxTestCase):
|
||||
@ -59,7 +58,7 @@ class JaxprStatsTest(jtu.JaxTestCase):
|
||||
|
||||
hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)
|
||||
|
||||
t = '64' if FLAGS.jax_enable_x64 else '32'
|
||||
t = '64' if config.x64_enabled else '32'
|
||||
shapes = [
|
||||
f'add :: float{t}[]',
|
||||
f'sin :: float{t}[]',
|
||||
|
@ -34,7 +34,6 @@ from jax._src import util
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
# We disable the whitespace continuation check in this file because otherwise it
|
||||
# makes the test name formatting unwieldy.
|
||||
@ -54,7 +53,7 @@ IndexSpec = collections.namedtuple("IndexTest", ["shape", "indexer"])
|
||||
|
||||
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
|
||||
# TODO(mattjj,dougalm): add higher-order check
|
||||
default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
|
||||
default_tol = 1e-6 if config.x64_enabled else 1e-2
|
||||
atol = atol or default_tol
|
||||
rtol = rtol or default_tol
|
||||
eps = eps or default_tol
|
||||
|
@ -660,7 +660,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for rec in JAX_BITWISE_OP_RECORDS))
|
||||
def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes):
|
||||
rng = rng_factory(self.rng())
|
||||
if not FLAGS.jax_enable_x64 and any(
|
||||
if not config.x64_enabled and any(
|
||||
jnp.iinfo(dtype).bits == 64 for dtype in dtypes):
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
@ -684,7 +684,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np.issubdtype(shift_dtype, np.signedinteger)
|
||||
has_32 = any(np.iinfo(d).bits == 32 for d in dtypes)
|
||||
promoting_to_64 = has_32 and signed_mix
|
||||
if promoting_to_64 and not FLAGS.jax_enable_x64:
|
||||
if promoting_to_64 and not config.x64_enabled:
|
||||
self.skipTest("np.right_shift/left_shift promoting to int64"
|
||||
"differs from jnp in 32 bit mode.")
|
||||
|
||||
@ -2317,7 +2317,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory):
|
||||
# integer types are converted to float64 in numpy's implementation
|
||||
if (x1_dtype not in [jnp.bfloat16, np.float16, np.float32]
|
||||
and not FLAGS.jax_enable_x64):
|
||||
and not config.x64_enabled):
|
||||
self.skipTest("Only run float64 testcase when float64 is enabled.")
|
||||
x1_rng = x1_rng_factory(self.rng())
|
||||
x2_rng = x2_rng_factory(self.rng())
|
||||
@ -2344,7 +2344,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testFrexp(self, shape, dtype, rng_factory):
|
||||
# integer types are converted to float64 in numpy's implementation
|
||||
if (dtype not in [jnp.bfloat16, np.float16, np.float32]
|
||||
and not FLAGS.jax_enable_x64):
|
||||
and not config.x64_enabled):
|
||||
self.skipTest("Only run float64 testcase when float64 is enabled.")
|
||||
rng = rng_factory(self.rng())
|
||||
np_fun = lambda x: np.frexp(x)
|
||||
@ -3326,7 +3326,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
|
||||
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
|
||||
if not FLAGS.jax_enable_x64:
|
||||
if not config.x64_enabled:
|
||||
if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8:
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
rng = jtu.rand_fullrange(self.rng())
|
||||
@ -4391,7 +4391,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
endpoint, base, dtype):
|
||||
if (dtype in int_dtypes and
|
||||
jtu.device_under_test() in ("gpu", "tpu") and
|
||||
not FLAGS.jax_enable_x64):
|
||||
not config.x64_enabled):
|
||||
raise unittest.SkipTest("GPUx32 truncated exponentiation"
|
||||
" doesn't exactly match other platforms.")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
@ -90,7 +90,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
for dtype in [np.float64, np.complex128]
|
||||
for preconditioner in [None, 'identity', 'exact', 'random']))
|
||||
def test_cg_against_scipy(self, shape, dtype, preconditioner):
|
||||
if not config.FLAGS.jax_enable_x64:
|
||||
if not config.x64_enabled:
|
||||
raise unittest.SkipTest("requires x64 mode")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -208,7 +208,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
for solve_method in ['incremental', 'batched']))
|
||||
def test_gmres_against_scipy(
|
||||
self, shape, dtype, preconditioner, solve_method):
|
||||
if not config.FLAGS.jax_enable_x64:
|
||||
if not config.x64_enabled:
|
||||
raise unittest.SkipTest("requires x64 mode")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -325,7 +325,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
"""
|
||||
The Arnoldi decomposition within GMRES is correct.
|
||||
"""
|
||||
if not config.FLAGS.jax_enable_x64:
|
||||
if not config.x64_enabled:
|
||||
raise unittest.SkipTest("requires x64 mode")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
@ -41,7 +41,6 @@ from jax._src.lax.lax import _device_put_raw
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
### lax tests
|
||||
@ -194,7 +193,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for dtype in rec.dtypes)
|
||||
for rec in LAX_OPS))
|
||||
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
|
||||
if (not FLAGS.jax_enable_x64 and op_name == "nextafter"
|
||||
if (not config.x64_enabled and op_name == "nextafter"
|
||||
and dtype == np.float64):
|
||||
raise SkipTest("64-bit mode disabled")
|
||||
rng = rng_factory(self.rng())
|
||||
@ -897,7 +896,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
(np.int8, np.int64), (np.int16, np.int16), (np.int16, np.int32), (np.int16, np.int64),
|
||||
(np.int32, np.int32), (np.int32, np.int64), (np.int64, np.int64)]))
|
||||
def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype, preferred_element_type):
|
||||
if (not FLAGS.jax_enable_x64 and
|
||||
if (not config.x64_enabled and
|
||||
(dtype == np.float64 or preferred_element_type == np.float64
|
||||
or dtype == np.int64 or preferred_element_type == np.int64)):
|
||||
raise SkipTest("64-bit mode disabled")
|
||||
|
@ -37,7 +37,6 @@ import jax._src.random
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
float_dtypes = jtu.dtypes.all_floating
|
||||
complex_dtypes = jtu.dtypes.complex
|
||||
@ -156,7 +155,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
|
||||
bits64 = jax._src.random._random_bits(key, 64, (3,))
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
expected64 = np.array([3982329540505020460, 16822122385914693683,
|
||||
7882654074788531506], dtype=np.uint64)
|
||||
else:
|
||||
@ -397,7 +396,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for b in [0.2, 5.]
|
||||
for dtype in [np.float64])) # NOTE: KS test fails with float32
|
||||
def testBeta(self, a, b, dtype):
|
||||
if not FLAGS.jax_enable_x64:
|
||||
if not config.x64_enabled:
|
||||
raise SkipTest("skip test except on X64")
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
|
||||
@ -728,7 +727,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testIssue756(self):
|
||||
key = random.PRNGKey(0)
|
||||
w = random.normal(key, ())
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
self.assertEqual(np.result_type(w), np.float64)
|
||||
else:
|
||||
self.assertEqual(np.result_type(w), np.float32)
|
||||
@ -751,7 +750,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# Test to ensure consistent random values between JAX versions
|
||||
k = random.PRNGKey(0)
|
||||
|
||||
if FLAGS.jax_enable_x64:
|
||||
if config.x64_enabled:
|
||||
self.assertAllClose(
|
||||
random.randint(k, (3, 3), 0, 8),
|
||||
np.array([[7, 2, 6],
|
||||
@ -917,20 +916,20 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"seed": 2, "type": np.uint32, "jit": False, "key": [0, 2]},
|
||||
{"seed": 3, "type": np.int64, "jit": True, "key": [0, 3]},
|
||||
{"seed": 3, "type": np.int64, "jit": False, "key": [0, 3]},
|
||||
{"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if FLAGS.jax_enable_x64 else [0, 4294967295]},
|
||||
{"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if FLAGS.jax_enable_x64 else [0, 4294967295]},
|
||||
{"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -2, "type": np.int32, "jit": True, "key": [0, 4294967294]},
|
||||
{"seed": -2, "type": np.int32, "jit": False, "key": [0, 4294967294]},
|
||||
{"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if FLAGS.jax_enable_x64 else [0, 4294967293]},
|
||||
{"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if FLAGS.jax_enable_x64 else [0, 4294967293]},
|
||||
{"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": True, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": False, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": True, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": False, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if FLAGS.jax_enable_x64 else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if FLAGS.jax_enable_x64 else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if FLAGS.jax_enable_x64 else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if FLAGS.jax_enable_x64 else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
]
|
||||
))
|
||||
def test_prng_seeds_and_keys(self, seed, type, jit, key):
|
||||
|
@ -12,13 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import concurrent.futures
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import api, config, lax, partial, random
|
||||
from jax import api, lax, partial, random
|
||||
from jax.config import config, FLAGS
|
||||
from jax.experimental import enable_x64, disable_x64
|
||||
import jax.numpy as jnp
|
||||
import jax.test_util as jtu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _maybe_jit(jit_type, func, *args, **kwargs):
|
||||
if jit_type == "python":
|
||||
return api._python_jit(func, *args, **kwargs)
|
||||
@ -84,3 +92,33 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
|
||||
with disable_x64():
|
||||
self.assertArraysEqual(count_to(10), jnp.float32(10), check_dtypes=True)
|
||||
|
||||
def test_thread_safety(self):
|
||||
def func_x32():
|
||||
with disable_x64():
|
||||
time.sleep(0.1)
|
||||
return jnp.arange(10).dtype
|
||||
|
||||
def func_x64():
|
||||
with enable_x64():
|
||||
time.sleep(0.1)
|
||||
return jnp.arange(10).dtype
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
x32 = executor.submit(func_x32)
|
||||
x64 = executor.submit(func_x64)
|
||||
self.assertEqual(x64.result(), jnp.int64)
|
||||
self.assertEqual(x32.result(), jnp.int32)
|
||||
|
||||
def test_jit_cache(self):
|
||||
# TODO(jakevdp): enable this test when CPP jit cache is fixed.
|
||||
if FLAGS.experimental_cpp_jit:
|
||||
self.skipTest("Known failure due to https://github.com/google/jax/issues/5532")
|
||||
|
||||
f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
|
||||
with disable_x64():
|
||||
for _ in range(2):
|
||||
f()
|
||||
with enable_x64():
|
||||
for _ in range(2):
|
||||
f()
|
||||
|
Loading…
x
Reference in New Issue
Block a user