mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #18753 from jakevdp:warnings-tests
PiperOrigin-RevId: 586739682
This commit is contained in:
commit
4de07b3f62
@ -1040,9 +1040,9 @@ class JaxTestCase(parameterized.TestCase):
|
||||
|
||||
@contextmanager
|
||||
def assertNoWarnings(self):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
yield
|
||||
self.assertEmpty(caught_warnings)
|
||||
|
||||
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
|
||||
rtol=None, atol=None, check_cache_misses=True):
|
||||
@ -1124,9 +1124,9 @@ class BufferDonationTestCase(JaxTestCase):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ignore_warning(**kw):
|
||||
def ignore_warning(*, message='', category=Warning, **kw):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", **kw)
|
||||
warnings.filterwarnings("ignore", message=message, category=category, **kw)
|
||||
yield
|
||||
|
||||
# -------------------- Mesh parametrization helpers --------------------
|
||||
|
@ -34,7 +34,6 @@ import sys
|
||||
import types
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
from absl import logging
|
||||
@ -394,16 +393,9 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
y = jnp.array([1, 2], jnp.int32)
|
||||
f = jit(lambda x, y: x.sum() + jnp.float32(y.sum()),
|
||||
**{argnum_type: argnum_val})
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"):
|
||||
f(x, y)
|
||||
|
||||
self.assertLen(w, 1)
|
||||
self.assertTrue(issubclass(w[-1].category, UserWarning))
|
||||
self.assertIn(
|
||||
"Some donated buffers were not usable:",
|
||||
str(w[-1].message))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("argnums", "donate_argnums", 0),
|
||||
("argnames", "donate_argnames", 'x'),
|
||||
@ -480,8 +472,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
x = jnp.asarray([0, 1])
|
||||
x_copy = jnp.array(x, copy=True)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
with jtu.ignore_warning():
|
||||
_test(x) # donation
|
||||
|
||||
# Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
|
||||
@ -2879,10 +2870,9 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_dtype_from_builtin_types(self):
|
||||
for dtype in [bool, int, float, complex]:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with self.assertNoWarnings():
|
||||
x = jnp.array(0, dtype=dtype)
|
||||
self.assertEmpty(caught_warnings)
|
||||
assert x.dtype == dtypes.canonicalize_dtype(dtype)
|
||||
self.assertEqual(x.dtype, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def test_dtype_warning(self):
|
||||
# cf. issue #1230
|
||||
@ -2890,24 +2880,10 @@ class APITest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("test only applies when x64 is disabled")
|
||||
|
||||
def check_warning(warn, nowarn):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
nowarn() # get rid of extra startup warning
|
||||
|
||||
prev_len = len(w)
|
||||
nowarn()
|
||||
assert len(w) == prev_len
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype"):
|
||||
warn()
|
||||
assert len(w) > 0
|
||||
msg = str(w[-1].message)
|
||||
expected_prefix = "Explicitly requested dtype "
|
||||
self.assertEqual(expected_prefix, msg[:len(expected_prefix)])
|
||||
|
||||
prev_len = len(w)
|
||||
with self.assertNoWarnings():
|
||||
nowarn()
|
||||
assert len(w) == prev_len
|
||||
|
||||
check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"),
|
||||
lambda: jnp.array([1, 2, 3], dtype="float32"))
|
||||
|
@ -14,7 +14,6 @@
|
||||
import functools
|
||||
import threading
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
@ -535,8 +534,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
with jtu.ignore_warning():
|
||||
f(jnp.arange(jax.device_count())) # doesn't crash
|
||||
|
||||
def test_cant_jit_and_pmap_function_with_ordered_effects(self):
|
||||
|
@ -18,7 +18,6 @@ from functools import partial
|
||||
import itertools
|
||||
import typing
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -1524,10 +1523,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
def testIndexDtypeError(self):
|
||||
# https://github.com/google/jax/issues/2795
|
||||
jnp.array(1) # get rid of startup warning
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("error")
|
||||
with self.assertNoWarnings():
|
||||
jnp.zeros(5).at[::2].set(1)
|
||||
self.assertLen(w, 0)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(idx=idx, idx_type=idx_type)
|
||||
|
@ -25,7 +25,6 @@ import platform
|
||||
from typing import cast, Optional
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -4564,8 +4563,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]])
|
||||
|
||||
# matrix directives
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
|
||||
with jtu.ignore_warning(category=PendingDeprecationWarning):
|
||||
self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]])
|
||||
self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]])
|
||||
|
||||
@ -4613,8 +4611,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]])
|
||||
self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]])
|
||||
# matrix directives, avoid numpy deprecation warning
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
|
||||
with jtu.ignore_warning(category=PendingDeprecationWarning):
|
||||
self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]])
|
||||
self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]])
|
||||
|
||||
@ -5497,8 +5494,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
|
||||
for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
|
||||
args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning)
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"):
|
||||
_ = func(*args)
|
||||
except TypeError:
|
||||
pass
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
import functools
|
||||
import math
|
||||
import warnings
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from absl import flags
|
||||
@ -929,14 +928,9 @@ class MemoriesComputationTest(jtu.BufferDonationTestCase):
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"):
|
||||
f(inp)
|
||||
|
||||
self.assertLen(w, 1)
|
||||
self.assertTrue(issubclass(w[-1].category, UserWarning))
|
||||
self.assertIn("Some donated buffers were not usable:", str(w[-1].message))
|
||||
|
||||
lowered_text = f.lower(inp).as_text("hlo")
|
||||
self.assertNotIn("input_output_alias", lowered_text)
|
||||
self.assertNotDeleted(inp)
|
||||
|
@ -24,7 +24,6 @@ import re
|
||||
from typing import Optional, cast
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
@ -1848,14 +1847,9 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
def foo(x): return x
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with self.assertWarnsRegex(UserWarning, "The jitted function foo includes a pmap"):
|
||||
jit(self.pmap(foo))(jnp.arange(device_count))
|
||||
|
||||
self.assertGreaterEqual(len(w), 1)
|
||||
self.assertIn("The jitted function foo includes a pmap",
|
||||
str(w[-1].message))
|
||||
|
||||
def testJitOfPmapOutputSharding(self):
|
||||
device_count = jax.device_count()
|
||||
|
||||
|
@ -355,12 +355,8 @@ class GetBackendTest(jtu.JaxTestCase):
|
||||
xb.get_backend("none")
|
||||
|
||||
def cpu_fallback_warning(self):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with self.assertWarnsRegex(UserWarning, "No GPU/TPU found, falling back to CPU"):
|
||||
xb.get_backend()
|
||||
self.assertLen(w, 1)
|
||||
msg = str(w[-1].message)
|
||||
self.assertIn("No GPU/TPU found, falling back to CPU", msg)
|
||||
|
||||
def test_jax_platforms_flag(self):
|
||||
self._register_factory("platform_A", 20, assert_used_at_most_once=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user