Merge pull request #18753 from jakevdp:warnings-tests

PiperOrigin-RevId: 586739682
This commit is contained in:
jax authors 2023-11-30 11:37:24 -08:00
commit 4de07b3f62
8 changed files with 18 additions and 67 deletions

View File

@ -1040,9 +1040,9 @@ class JaxTestCase(parameterized.TestCase):
@contextmanager
def assertNoWarnings(self):
with warnings.catch_warnings(record=True) as caught_warnings:
with warnings.catch_warnings():
warnings.simplefilter("error")
yield
self.assertEmpty(caught_warnings)
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
rtol=None, atol=None, check_cache_misses=True):
@ -1124,9 +1124,9 @@ class BufferDonationTestCase(JaxTestCase):
@contextmanager
def ignore_warning(**kw):
def ignore_warning(*, message='', category=Warning, **kw):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", **kw)
warnings.filterwarnings("ignore", message=message, category=category, **kw)
yield
# -------------------- Mesh parametrization helpers --------------------

View File

@ -34,7 +34,6 @@ import sys
import types
from typing import Callable, NamedTuple, Optional
import unittest
import warnings
import weakref
from absl import logging
@ -394,16 +393,9 @@ class JitTest(jtu.BufferDonationTestCase):
y = jnp.array([1, 2], jnp.int32)
f = jit(lambda x, y: x.sum() + jnp.float32(y.sum()),
**{argnum_type: argnum_val})
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"):
f(x, y)
self.assertLen(w, 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
self.assertIn(
"Some donated buffers were not usable:",
str(w[-1].message))
@parameterized.named_parameters(
("argnums", "donate_argnums", 0),
("argnames", "donate_argnames", 'x'),
@ -480,8 +472,7 @@ class JitTest(jtu.BufferDonationTestCase):
x = jnp.asarray([0, 1])
x_copy = jnp.array(x, copy=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with jtu.ignore_warning():
_test(x) # donation
# Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
@ -2879,10 +2870,9 @@ class APITest(jtu.JaxTestCase):
def test_dtype_from_builtin_types(self):
for dtype in [bool, int, float, complex]:
with warnings.catch_warnings(record=True) as caught_warnings:
with self.assertNoWarnings():
x = jnp.array(0, dtype=dtype)
self.assertEmpty(caught_warnings)
assert x.dtype == dtypes.canonicalize_dtype(dtype)
self.assertEqual(x.dtype, dtypes.canonicalize_dtype(dtype))
def test_dtype_warning(self):
# cf. issue #1230
@ -2890,24 +2880,10 @@ class APITest(jtu.JaxTestCase):
raise unittest.SkipTest("test only applies when x64 is disabled")
def check_warning(warn, nowarn):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
nowarn() # get rid of extra startup warning
prev_len = len(w)
nowarn()
assert len(w) == prev_len
with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype"):
warn()
assert len(w) > 0
msg = str(w[-1].message)
expected_prefix = "Explicitly requested dtype "
self.assertEqual(expected_prefix, msg[:len(expected_prefix)])
prev_len = len(w)
with self.assertNoWarnings():
nowarn()
assert len(w) == prev_len
check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"),
lambda: jnp.array([1, 2, 3], dtype="float32"))

View File

@ -14,7 +14,6 @@
import functools
import threading
import unittest
import warnings
from absl.testing import absltest
import jax
@ -535,8 +534,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect=bar_effect)
return x + 1
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with jtu.ignore_warning():
f(jnp.arange(jax.device_count())) # doesn't crash
def test_cant_jit_and_pmap_function_with_ordered_effects(self):

View File

@ -18,7 +18,6 @@ from functools import partial
import itertools
import typing
from typing import Any, Optional
import warnings
from absl.testing import absltest
from absl.testing import parameterized
@ -1524,10 +1523,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
def testIndexDtypeError(self):
# https://github.com/google/jax/issues/2795
jnp.array(1) # get rid of startup warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("error")
with self.assertNoWarnings():
jnp.zeros(5).at[::2].set(1)
self.assertLen(w, 0)
@jtu.sample_product(
[dict(idx=idx, idx_type=idx_type)

View File

@ -25,7 +25,6 @@ import platform
from typing import cast, Optional
import unittest
from unittest import SkipTest
import warnings
from absl.testing import absltest
from absl.testing import parameterized
@ -4564,8 +4563,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]])
# matrix directives
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
with jtu.ignore_warning(category=PendingDeprecationWarning):
self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]])
self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]])
@ -4613,8 +4611,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]])
# matrix directives, avoid numpy deprecation warning
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
with jtu.ignore_warning(category=PendingDeprecationWarning):
self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]])
@ -5497,8 +5494,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning)
with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"):
_ = func(*args)
except TypeError:
pass

View File

@ -14,7 +14,6 @@
import functools
import math
import warnings
from absl.testing import absltest
from absl.testing import parameterized
from absl import flags
@ -929,14 +928,9 @@ class MemoriesComputationTest(jtu.BufferDonationTestCase):
def f(x):
return x * 2
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"):
f(inp)
self.assertLen(w, 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
self.assertIn("Some donated buffers were not usable:", str(w[-1].message))
lowered_text = f.lower(inp).as_text("hlo")
self.assertNotIn("input_output_alias", lowered_text)
self.assertNotDeleted(inp)

View File

@ -24,7 +24,6 @@ import re
from typing import Optional, cast
import unittest
from unittest import SkipTest
import warnings
import weakref
import numpy as np
@ -1848,14 +1847,9 @@ class PythonPmapTest(jtu.JaxTestCase):
def foo(x): return x
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "The jitted function foo includes a pmap"):
jit(self.pmap(foo))(jnp.arange(device_count))
self.assertGreaterEqual(len(w), 1)
self.assertIn("The jitted function foo includes a pmap",
str(w[-1].message))
def testJitOfPmapOutputSharding(self):
device_count = jax.device_count()

View File

@ -355,12 +355,8 @@ class GetBackendTest(jtu.JaxTestCase):
xb.get_backend("none")
def cpu_fallback_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "No GPU/TPU found, falling back to CPU"):
xb.get_backend()
self.assertLen(w, 1)
msg = str(w[-1].message)
self.assertIn("No GPU/TPU found, falling back to CPU", msg)
def test_jax_platforms_flag(self):
self._register_factory("platform_A", 20, assert_used_at_most_once=True)