mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure. This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
This commit is contained in:
parent
640cb009f1
commit
b06779b177
@ -120,6 +120,7 @@ py_library(
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"_src/test_util.py",
|
||||
"_src/test_warning_util.py",
|
||||
],
|
||||
visibility = [
|
||||
":internal",
|
||||
|
@ -35,7 +35,6 @@ import threading
|
||||
import time
|
||||
from typing import Any, TextIO
|
||||
import unittest
|
||||
import warnings
|
||||
import zlib
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -49,6 +48,7 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import lib as _jaxlib
|
||||
from jax._src import monitoring
|
||||
from jax._src import test_warning_util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src import util
|
||||
from jax._src import mesh as mesh_lib
|
||||
@ -118,7 +118,7 @@ HYPOTHESIS_PROFILE = config.string_flag(
|
||||
)
|
||||
|
||||
TEST_NUM_THREADS = config.int_flag(
|
||||
'jax_test_num_threads', 0,
|
||||
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
|
||||
help='Number of threads to use for running tests. 0 means run everything '
|
||||
'in the main thread. Using > 1 thread is experimental.'
|
||||
)
|
||||
@ -1076,7 +1076,7 @@ class ThreadSafeTestResult:
|
||||
with self.lock:
|
||||
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
|
||||
# override how it gets the time.
|
||||
time_getter = self.test_result.time_getter
|
||||
time_getter = getattr(self.test_result, "time_getter", None)
|
||||
try:
|
||||
self.test_result.time_getter = lambda: self.start_time
|
||||
self.test_result.startTest(test)
|
||||
@ -1085,7 +1085,8 @@ class ThreadSafeTestResult:
|
||||
self.test_result.time_getter = lambda: stop_time
|
||||
self.test_result.stopTest(test)
|
||||
finally:
|
||||
self.test_result.time_getter = time_getter
|
||||
if time_getter is not None:
|
||||
self.test_result.time_getter = time_getter
|
||||
|
||||
def addSuccess(self, test: unittest.TestCase):
|
||||
self.actions.append(lambda: self.test_result.addSuccess(test))
|
||||
@ -1120,6 +1121,8 @@ class JaxTestSuite(unittest.TestSuite):
|
||||
if TEST_NUM_THREADS.value <= 0:
|
||||
return super().run(result)
|
||||
|
||||
test_warning_util.install_threadsafe_warning_handlers()
|
||||
|
||||
executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
|
||||
lock = threading.Lock()
|
||||
futures = []
|
||||
@ -1368,12 +1371,45 @@ class JaxTestCase(parameterized.TestCase):
|
||||
self.assertMultiLineEqual(expected_clean, what_clean,
|
||||
msg=f"Found\n{what}\nExpecting\n{expected}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def assertNoWarnings(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
with test_warning_util.raise_on_warnings():
|
||||
yield
|
||||
|
||||
# We replace assertWarns and assertWarnsRegex with functions that use the
|
||||
# thread-safe warning utilities. Unlike the unittest versions these only
|
||||
# function as context managers.
|
||||
@contextmanager
|
||||
def assertWarns(self, warning, *, msg=None):
|
||||
with test_warning_util.record_warnings() as ws:
|
||||
yield
|
||||
for w in ws:
|
||||
if not isinstance(w.message, warning):
|
||||
continue
|
||||
if msg is not None and msg not in str(w.message):
|
||||
continue
|
||||
return
|
||||
self.fail(f"Expected warning not found {warning}:'{msg}', got "
|
||||
f"{ws}")
|
||||
|
||||
@contextmanager
|
||||
def assertWarnsRegex(self, warning, regex):
|
||||
if regex is not None:
|
||||
regex = re.compile(regex)
|
||||
|
||||
with test_warning_util.record_warnings() as ws:
|
||||
yield
|
||||
for w in ws:
|
||||
if not isinstance(w.message, warning):
|
||||
continue
|
||||
if regex is not None and not regex.search(str(w.message)):
|
||||
continue
|
||||
return
|
||||
self.fail(f"Expected warning not found {warning}:'{regex}', got "
|
||||
f"{ws}")
|
||||
|
||||
|
||||
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
|
||||
rtol=None, atol=None, check_cache_misses=True):
|
||||
"""Helper method for running JAX compilation and allclose assertions."""
|
||||
@ -1449,11 +1485,7 @@ class BufferDonationTestCase(JaxTestCase):
|
||||
self.assertFalse(x.is_deleted())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ignore_warning(*, message='', category=Warning, **kw):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message=message, category=category, **kw)
|
||||
yield
|
||||
ignore_warning = test_warning_util.ignore_warning
|
||||
|
||||
# -------------------- Mesh parametrization helpers --------------------
|
||||
|
||||
@ -1768,9 +1800,8 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
|
||||
logtiny = finfo.minexp / prec_dps_ratio
|
||||
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
with ignore_warning(category=RuntimeWarning):
|
||||
# Silence RuntimeWarning: overflow encountered in cast
|
||||
warnings.simplefilter("ignore")
|
||||
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
|
||||
half_line = -half_neg_line[::-1]
|
||||
axis_points[-size - 1:-1] = half_line
|
||||
|
132
jax/_src/test_warning_util.py
Normal file
132
jax/_src/test_warning_util.py
Normal file
@ -0,0 +1,132 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Thread-safe utilities for catching and testing for warnings.
|
||||
#
|
||||
# The Python warnings module, at least as of Python 3.13, is not thread-safe.
|
||||
# The catch_warnings() feature is inherently racy, see
|
||||
# https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe
|
||||
#
|
||||
# This module offers a thread-safe way to catch and record warnings. We install
|
||||
# a custom showwarning hook with the Python warning module, and then rely on
|
||||
# the CPython warnings module to call our show warning function. We then use it
|
||||
# to create our own thread-safe warning filtering utilities.
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
|
||||
class _WarningContext(threading.local):
|
||||
"Thread-local state that contains a list of warning handlers."
|
||||
|
||||
def __init__(self):
|
||||
self.handlers = []
|
||||
|
||||
|
||||
_context = _WarningContext()
|
||||
|
||||
|
||||
# Callback that applies the handlers in reverse order. If no handler matches,
|
||||
# we raise an error.
|
||||
def _showwarning(message, category, filename, lineno, file=None, line=None):
|
||||
for handler in reversed(_context.handlers):
|
||||
if handler(message, category, filename, lineno, file, line):
|
||||
return
|
||||
raise category(message)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def raise_on_warnings():
|
||||
"Context manager that raises an exception if a warning is raised."
|
||||
if warnings.showwarning is not _showwarning:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
yield
|
||||
return
|
||||
|
||||
def handler(message, category, filename, lineno, file=None, line=None):
|
||||
raise category(message)
|
||||
|
||||
_context.handlers.append(handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_context.handlers.pop()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def record_warnings():
|
||||
"Context manager that yields a list of warnings that are raised."
|
||||
if warnings.showwarning is not _showwarning:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
yield w
|
||||
return
|
||||
|
||||
log = []
|
||||
|
||||
def handler(message, category, filename, lineno, file=None, line=None):
|
||||
log.append(warnings.WarningMessage(message, category, filename, lineno, file, line))
|
||||
return True
|
||||
|
||||
_context.handlers.append(handler)
|
||||
try:
|
||||
yield log
|
||||
finally:
|
||||
_context.handlers.pop()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ignore_warning(*, message: str | None = None, category: type = Warning):
|
||||
"Context manager that ignores any matching warnings."
|
||||
if warnings.showwarning is not _showwarning:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="" if message is None else message, category=category)
|
||||
yield
|
||||
return
|
||||
|
||||
if message:
|
||||
message_re = re.compile(message)
|
||||
else:
|
||||
message_re = None
|
||||
|
||||
category_cls = category
|
||||
|
||||
def handler(message, category, filename, lineno, file=None, line=None):
|
||||
text = str(message) if isinstance(message, Warning) else message
|
||||
if (message_re is None or message_re.match(text)) and issubclass(
|
||||
category, category_cls
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
_context.handlers.append(handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_context.handlers.pop()
|
||||
|
||||
|
||||
def install_threadsafe_warning_handlers():
|
||||
# Hook the showwarning method. The warnings module explicitly notes that
|
||||
# this is a function that users may replace.
|
||||
warnings.showwarning = _showwarning
|
||||
|
||||
# Set the warnings module to always display warnings. We hook into it by
|
||||
# overriding the "showwarning" method, so it's important that all warnings
|
||||
# are "shown" by the usual mechanism.
|
||||
warnings.simplefilter("always")
|
@ -1153,6 +1153,14 @@ jax_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
name = "warnings_util_test",
|
||||
srcs = ["warnings_util_test.py"],
|
||||
deps = [
|
||||
"//jax:test_util",
|
||||
] + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
name = "xla_bridge_test",
|
||||
srcs = ["xla_bridge_test.py"],
|
||||
|
@ -23,7 +23,6 @@ import platform
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from unittest import SkipTest
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -39,6 +38,7 @@ from jax._src import distributed
|
||||
from jax._src import monitoring
|
||||
from jax._src import path as pathlib
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import test_warning_util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.compilation_cache_interface import CacheInterface
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -232,21 +232,20 @@ class CompilationCacheTest(CompilationCacheTestCase):
|
||||
with (
|
||||
config.raise_persistent_cache_errors(False),
|
||||
mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put,
|
||||
warnings.catch_warnings(record=True) as w,
|
||||
test_warning_util.record_warnings() as w,
|
||||
):
|
||||
warnings.simplefilter("always")
|
||||
mock_put.side_effect = RuntimeError("test error")
|
||||
self.assertEqual(f(2).item(), 4)
|
||||
if len(w) != 1:
|
||||
print("Warnings:", [str(w_) for w_ in w], flush=True)
|
||||
self.assertLen(w, 1)
|
||||
self.assertIn(
|
||||
(
|
||||
"Error writing persistent compilation cache entry "
|
||||
"for 'jit__lambda_': RuntimeError: test error"
|
||||
),
|
||||
str(w[0].message),
|
||||
)
|
||||
if len(w) != 1:
|
||||
print("Warnings:", [str(w_) for w_ in w], flush=True)
|
||||
self.assertLen(w, 1)
|
||||
self.assertIn(
|
||||
(
|
||||
"Error writing persistent compilation cache entry "
|
||||
"for 'jit__lambda_': RuntimeError: test error"
|
||||
),
|
||||
str(w[0].message),
|
||||
)
|
||||
|
||||
def test_cache_read_warning(self):
|
||||
f = jit(lambda x: x * x)
|
||||
@ -255,23 +254,22 @@ class CompilationCacheTest(CompilationCacheTestCase):
|
||||
with (
|
||||
config.raise_persistent_cache_errors(False),
|
||||
mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get,
|
||||
warnings.catch_warnings(record=True) as w,
|
||||
test_warning_util.record_warnings() as w,
|
||||
):
|
||||
warnings.simplefilter("always")
|
||||
mock_get.side_effect = RuntimeError("test error")
|
||||
# Calling assertEqual with the jitted f will generate two PJIT
|
||||
# executables: Equal and the lambda function itself.
|
||||
self.assertEqual(f(2).item(), 4)
|
||||
if len(w) != 1:
|
||||
print("Warnings:", [str(w_) for w_ in w], flush=True)
|
||||
self.assertLen(w, 1)
|
||||
self.assertIn(
|
||||
(
|
||||
"Error reading persistent compilation cache entry "
|
||||
"for 'jit__lambda_': RuntimeError: test error"
|
||||
),
|
||||
str(w[0].message),
|
||||
)
|
||||
if len(w) != 1:
|
||||
print("Warnings:", [str(w_) for w_ in w], flush=True)
|
||||
self.assertLen(w, 1)
|
||||
self.assertIn(
|
||||
(
|
||||
"Error reading persistent compilation cache entry "
|
||||
"for 'jit__lambda_': RuntimeError: test error"
|
||||
),
|
||||
str(w[0].message),
|
||||
)
|
||||
|
||||
def test_min_entry_size(self):
|
||||
with (
|
||||
|
@ -12,18 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax._src import deprecations
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import test_warning_util
|
||||
from jax._src.internal_test_util import deprecation_module as m
|
||||
|
||||
class DeprecationTest(absltest.TestCase):
|
||||
|
||||
def testModuleDeprecation(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
with test_warning_util.raise_on_warnings():
|
||||
self.assertEqual(m.x, 42)
|
||||
|
||||
with self.assertWarnsRegex(DeprecationWarning, "Please use x"):
|
||||
|
@ -212,7 +212,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="mean of empty slice.*")
|
||||
message="Mean of empty slice.*")
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="overflow encountered.*")
|
||||
def np_fun(x):
|
||||
|
@ -5581,8 +5581,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp.ones(2) + 3 # don't want to raise for scalars
|
||||
|
||||
with jax.numpy_rank_promotion('warn'):
|
||||
self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
|
||||
r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2)))
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"Following NumPy automatic rank promotion for add on shapes "
|
||||
r"\(2,\) \(1, 2\).*"
|
||||
):
|
||||
jnp.ones(2) + jnp.ones((1, 2))
|
||||
jnp.ones(2) + 3 # don't want to warn for scalars
|
||||
|
||||
@unittest.skip("Test fails on CI, perhaps due to JIT caching")
|
||||
|
86
tests/warnings_util_test.py
Normal file
86
tests/warnings_util_test.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import test_warning_util
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
class WarningsUtilTest(jtu.JaxTestCase):
|
||||
|
||||
@test_warning_util.raise_on_warnings()
|
||||
def test_warning_raises(self):
|
||||
with self.assertRaises(UserWarning, msg="hello"):
|
||||
warnings.warn("hello", category=UserWarning)
|
||||
|
||||
with self.assertRaises(DeprecationWarning, msg="hello"):
|
||||
warnings.warn("hello", category=DeprecationWarning)
|
||||
|
||||
@test_warning_util.raise_on_warnings()
|
||||
def test_ignore_warning(self):
|
||||
with test_warning_util.ignore_warning(message="h.*o"):
|
||||
warnings.warn("hello", category=UserWarning)
|
||||
|
||||
with self.assertRaises(UserWarning, msg="hello"):
|
||||
with test_warning_util.ignore_warning(message="h.*o"):
|
||||
warnings.warn("goodbye", category=UserWarning)
|
||||
|
||||
with test_warning_util.ignore_warning(category=UserWarning):
|
||||
warnings.warn("hello", category=UserWarning)
|
||||
|
||||
with self.assertRaises(UserWarning, msg="hello"):
|
||||
with test_warning_util.ignore_warning(category=DeprecationWarning):
|
||||
warnings.warn("goodbye", category=UserWarning)
|
||||
|
||||
def test_record_warning(self):
|
||||
with test_warning_util.record_warnings() as w:
|
||||
warnings.warn("hello", category=UserWarning)
|
||||
warnings.warn("goodbye", category=DeprecationWarning)
|
||||
self.assertLen(w, 2)
|
||||
self.assertIs(w[0].category, UserWarning)
|
||||
self.assertIn("hello", str(w[0].message))
|
||||
self.assertIs(w[1].category, DeprecationWarning)
|
||||
self.assertIn("goodbye", str(w[1].message))
|
||||
|
||||
def test_record_warning_nested(self):
|
||||
with test_warning_util.record_warnings() as w:
|
||||
warnings.warn("aa", category=UserWarning)
|
||||
with test_warning_util.record_warnings() as v:
|
||||
warnings.warn("bb", category=UserWarning)
|
||||
warnings.warn("cc", category=DeprecationWarning)
|
||||
self.assertLen(w, 2)
|
||||
self.assertIs(w[0].category, UserWarning)
|
||||
self.assertIn("aa", str(w[0].message))
|
||||
self.assertIs(w[1].category, DeprecationWarning)
|
||||
self.assertIn("cc", str(w[1].message))
|
||||
self.assertLen(v, 1)
|
||||
self.assertIs(v[0].category, UserWarning)
|
||||
self.assertIn("bb", str(v[0].message))
|
||||
|
||||
|
||||
def test_raises_warning(self):
|
||||
with self.assertRaises(UserWarning, msg="hello"):
|
||||
with test_warning_util.ignore_warning():
|
||||
with test_warning_util.raise_on_warnings():
|
||||
warnings.warn("hello", category=UserWarning)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -14,8 +14,6 @@
|
||||
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import absltest
|
||||
@ -126,31 +124,6 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
|
||||
xb.local_devices(backend="foo")
|
||||
|
||||
def test_timer_tpu_warning(self):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
def _mock_tpu_client_with_options(library_path=None, options=None):
|
||||
time_to_wait = 5
|
||||
start = time.time()
|
||||
while not w:
|
||||
if time.time() - start > time_to_wait:
|
||||
raise ValueError(
|
||||
"This test should not hang for more than "
|
||||
f"{time_to_wait} seconds.")
|
||||
time.sleep(0.1)
|
||||
|
||||
self.assertLen(w, 1)
|
||||
msg = str(w[-1].message)
|
||||
self.assertIn("Did you run your code on all TPU hosts?", msg)
|
||||
|
||||
def _mock_tpu_client(library_path=None):
|
||||
_mock_tpu_client_with_options(library_path=library_path, options=None)
|
||||
|
||||
with mock.patch.object(xc, "make_tpu_client",
|
||||
side_effect=_mock_tpu_client_with_options):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
|
||||
def test_register_plugin(self):
|
||||
with self.assertLogs(level="WARNING") as log_output:
|
||||
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
|
||||
|
Loading…
x
Reference in New Issue
Block a user