1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 00:36:06 +00:00
rocm_jax/tests/deprecation_test.py
Peter Hawkins b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00

58 lines
2.0 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src import test_warning_util
from jax._src.internal_test_util import deprecation_module as m
class DeprecationTest(absltest.TestCase):
def testModuleDeprecation(self):
with test_warning_util.raise_on_warnings():
self.assertEqual(m.x, 42)
with self.assertWarnsRegex(DeprecationWarning, "Please use x"):
self.assertEqual(m.y, 101)
with self.assertRaisesRegex(AttributeError, "Please do not use z"):
_ = m.z
with self.assertRaisesRegex(AttributeError,
"module .* has no attribute 'w'"):
_ = m.w
def testNamedDeprecation(self):
some_unique_id = "some-unique-id"
try:
deprecations.register(some_unique_id)
self.assertFalse(deprecations.is_accelerated(some_unique_id))
deprecations.accelerate(some_unique_id)
self.assertTrue(deprecations.is_accelerated(some_unique_id))
finally:
deprecations.unregister(some_unique_id)
msg = f"deprecation_id={some_unique_id!r} not registered"
with self.assertRaisesRegex(ValueError, msg):
deprecations.accelerate(some_unique_id)
with self.assertRaisesRegex(ValueError, msg):
deprecations.is_accelerated(some_unique_id)
with self.assertRaisesRegex(ValueError, msg):
deprecations.unregister(some_unique_id)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())