diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 43b78c7c9..370492520 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1040,9 +1040,9 @@ class JaxTestCase(parameterized.TestCase): @contextmanager def assertNoWarnings(self): - with warnings.catch_warnings(record=True) as caught_warnings: + with warnings.catch_warnings(): + warnings.simplefilter("error") yield - self.assertEmpty(caught_warnings) def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None, rtol=None, atol=None, check_cache_misses=True): @@ -1124,9 +1124,9 @@ class BufferDonationTestCase(JaxTestCase): @contextmanager -def ignore_warning(**kw): +def ignore_warning(*, message='', category=Warning, **kw): with warnings.catch_warnings(): - warnings.filterwarnings("ignore", **kw) + warnings.filterwarnings("ignore", message=message, category=category, **kw) yield # -------------------- Mesh parametrization helpers -------------------- diff --git a/tests/api_test.py b/tests/api_test.py index 9d02229f8..05e730432 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -34,7 +34,6 @@ import sys import types from typing import Callable, NamedTuple, Optional import unittest -import warnings import weakref from absl import logging @@ -394,16 +393,9 @@ class JitTest(jtu.BufferDonationTestCase): y = jnp.array([1, 2], jnp.int32) f = jit(lambda x, y: x.sum() + jnp.float32(y.sum()), **{argnum_type: argnum_val}) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"): f(x, y) - self.assertLen(w, 1) - self.assertTrue(issubclass(w[-1].category, UserWarning)) - self.assertIn( - "Some donated buffers were not usable:", - str(w[-1].message)) - @parameterized.named_parameters( ("argnums", "donate_argnums", 0), ("argnames", "donate_argnames", 'x'), @@ -480,8 +472,7 @@ class JitTest(jtu.BufferDonationTestCase): x = jnp.asarray([0, 1]) x_copy = jnp.array(x, copy=True) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + with jtu.ignore_warning(): _test(x) # donation # Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. @@ -2879,10 +2870,9 @@ class APITest(jtu.JaxTestCase): def test_dtype_from_builtin_types(self): for dtype in [bool, int, float, complex]: - with warnings.catch_warnings(record=True) as caught_warnings: + with self.assertNoWarnings(): x = jnp.array(0, dtype=dtype) - self.assertEmpty(caught_warnings) - assert x.dtype == dtypes.canonicalize_dtype(dtype) + self.assertEqual(x.dtype, dtypes.canonicalize_dtype(dtype)) def test_dtype_warning(self): # cf. issue #1230 @@ -2890,24 +2880,10 @@ class APITest(jtu.JaxTestCase): raise unittest.SkipTest("test only applies when x64 is disabled") def check_warning(warn, nowarn): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - nowarn() # get rid of extra startup warning - - prev_len = len(w) - nowarn() - assert len(w) == prev_len - + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype"): warn() - assert len(w) > 0 - msg = str(w[-1].message) - expected_prefix = "Explicitly requested dtype " - self.assertEqual(expected_prefix, msg[:len(expected_prefix)]) - - prev_len = len(w) + with self.assertNoWarnings(): nowarn() - assert len(w) == prev_len check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"), lambda: jnp.array([1, 2, 3], dtype="float32")) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 161665c15..255822a26 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -14,7 +14,6 @@ import functools import threading import unittest -import warnings from absl.testing import absltest import jax @@ -535,8 +534,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def f(x): effect_p.bind(effect=bar_effect) return x + 1 - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + with jtu.ignore_warning(): f(jnp.arange(jax.device_count())) # doesn't crash def test_cant_jit_and_pmap_function_with_ordered_effects(self): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 404b34f7d..c77cdd35a 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -18,7 +18,6 @@ from functools import partial import itertools import typing from typing import Any, Optional -import warnings from absl.testing import absltest from absl.testing import parameterized @@ -1524,10 +1523,8 @@ class IndexedUpdateTest(jtu.JaxTestCase): def testIndexDtypeError(self): # https://github.com/google/jax/issues/2795 jnp.array(1) # get rid of startup warning - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("error") + with self.assertNoWarnings(): jnp.zeros(5).at[::2].set(1) - self.assertLen(w, 0) @jtu.sample_product( [dict(idx=idx, idx_type=idx_type) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 51cf262eb..9ce9a1cf6 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -25,7 +25,6 @@ import platform from typing import cast, Optional import unittest from unittest import SkipTest -import warnings from absl.testing import absltest from absl.testing import parameterized @@ -4564,8 +4563,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]]) # matrix directives - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + with jtu.ignore_warning(category=PendingDeprecationWarning): self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]]) self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]]) @@ -4613,8 +4611,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]]) self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]]) # matrix directives, avoid numpy deprecation warning - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + with jtu.ignore_warning(category=PendingDeprecationWarning): self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]]) self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]]) @@ -5497,8 +5494,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]: for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning) + with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"): _ = func(*args) except TypeError: pass diff --git a/tests/memories_test.py b/tests/memories_test.py index 1d7de4436..7f0be2043 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -14,7 +14,6 @@ import functools import math -import warnings from absl.testing import absltest from absl.testing import parameterized from absl import flags @@ -929,14 +928,9 @@ class MemoriesComputationTest(jtu.BufferDonationTestCase): def f(x): return x * 2 - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"): f(inp) - self.assertLen(w, 1) - self.assertTrue(issubclass(w[-1].category, UserWarning)) - self.assertIn("Some donated buffers were not usable:", str(w[-1].message)) - lowered_text = f.lower(inp).as_text("hlo") self.assertNotIn("input_output_alias", lowered_text) self.assertNotDeleted(inp) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index d0c42ee42..ef2816ff0 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -24,7 +24,6 @@ import re from typing import Optional, cast import unittest from unittest import SkipTest -import warnings import weakref import numpy as np @@ -1848,14 +1847,9 @@ class PythonPmapTest(jtu.JaxTestCase): def foo(x): return x - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "The jitted function foo includes a pmap"): jit(self.pmap(foo))(jnp.arange(device_count)) - self.assertGreaterEqual(len(w), 1) - self.assertIn("The jitted function foo includes a pmap", - str(w[-1].message)) - def testJitOfPmapOutputSharding(self): device_count = jax.device_count() diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index b5522e2e5..a3d78f1c2 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -355,12 +355,8 @@ class GetBackendTest(jtu.JaxTestCase): xb.get_backend("none") def cpu_fallback_warning(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "No GPU/TPU found, falling back to CPU"): xb.get_backend() - self.assertLen(w, 1) - msg = str(w[-1].message) - self.assertIn("No GPU/TPU found, falling back to CPU", msg) def test_jax_platforms_flag(self): self._register_factory("platform_A", 20, assert_used_at_most_once=True)