rocm_jax/jax/_src/test_loader.py
Peter Hawkins b4c3e38022 When running test cases concurrently, log the start and end of each test case.
This is very helpful for debugging deadlocks!

PiperOrigin-RevId: 745986596
2025-04-10 05:26:02 -07:00

223 lines
7.1 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a custom unittest loader and test suite.

Implements:
- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS
  environment variables (which supply the defaults for the test_targets and
  exclude_test_targets config options).
- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS
  is >= 1.
- Helpers that mark a block of test code (thread_unsafe_test) or a whole test
  class (thread_unsafe_test_class) as thread-hostile.
"""
from __future__ import annotations
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
import logging
import os
import re
import threading
import time
import unittest
from absl.testing import absltest
from jax._src import config
from jax._src import test_warning_util
from jax._src import util
# Module-level logger; ThreadSafeTestResult logs the start and stop of every
# test case through it.
logger = logging.getLogger(__name__)

# Include filter used by JaxTestLoader.getTestCaseNames. The default is read
# from the JAX_TEST_TARGETS environment variable. The pattern is applied with
# re.search to the string "<TestCaseClassName>.<test_method_name>".
_TEST_TARGETS = config.string_flag(
    'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
    'Regular expression specifying which tests to run, called via re.search on '
    'the test name. If empty or unspecified, run all tests.'
)

# Exclude filter used by JaxTestLoader.getTestCaseNames. It is applied after
# _TEST_TARGETS, to the same "<TestCaseClassName>.<test_method_name>" string.
# The default is read from the JAX_EXCLUDE_TEST_TARGETS environment variable.
_EXCLUDE_TEST_TARGETS = config.string_flag(
    'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
    'Regular expression specifying which tests NOT to run, called via re.search '
    'on the test name. If empty or unspecified, run all tests.'
)

# Size of the thread pool used by JaxTestSuite.run. Any value <= 0 disables
# threading there and in thread_unsafe_test. The default is read from
# JAX_TEST_NUM_THREADS at import time; int() raises ValueError if that
# variable is set to a non-integer string.
TEST_NUM_THREADS = config.int_flag(
    'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
    help='Number of threads to use for running tests. 0 means run everything '
    'in the main thread. Using > 1 thread is experimental.'
)

# We use a reader-writer lock to protect test execution. Tests that may run in
# parallel acquire a read lock; tests that are not thread-safe acquire a write
# lock.
_test_rwlock = util.Mutex()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
  """Runs one test case while holding the appropriate side of `_test_rwlock`.

  Classes marked `thread_hostile` (see `thread_unsafe_test_class`) take the
  writer side of the reader-writer lock. Every other test takes the reader
  side. The lock is released even if running the test raises.

  Args:
    test: the test case to run.
    result: the per-test result object handed to ``test(result)``.
  """
  if getattr(test.__class__, "thread_hostile", False):
    acquire, release = _test_rwlock.writer_lock, _test_rwlock.writer_unlock
  else:
    acquire, release = _test_rwlock.reader_lock, _test_rwlock.reader_unlock
  acquire()
  try:
    test(result)  # type: ignore
  finally:
    release()
@contextmanager
def thread_unsafe_test():
  """Context manager (also usable as a decorator) for code that is not thread-safe.

  In single-threaded mode (``TEST_NUM_THREADS <= 0``) this is a no-op.
  Otherwise the caller must hold the reader side of ``_test_rwlock``, which
  ``_run_one_test`` acquires for ordinary tests. For the duration of the
  wrapped code that reader lock is traded for the writer lock.

  The reader lock is re-acquired on exit, even if acquiring the writer lock
  or the wrapped code raises. This keeps the caller's own ``reader_unlock()``
  balanced.

  Note: this decorator (naturally) only applies to what it wraps, not to, say,
  code in separate setUp() or tearDown() methods.
  """
  if TEST_NUM_THREADS.value <= 0:
    yield
    return
  _test_rwlock.assert_reader_held()
  _test_rwlock.reader_unlock()
  try:
    _test_rwlock.writer_lock()
    try:
      yield
    finally:
      _test_rwlock.writer_unlock()
  finally:
    # Restore the reader lock the caller held on entry. Without it, the
    # caller's eventual reader_unlock() would release a lock it does not hold.
    _test_rwlock.reader_lock()
def thread_unsafe_test_class():
  """Returns a class decorator that marks a TestCase class as thread-hostile.

  The decorator sets ``thread_hostile = True`` on the class. ``_run_one_test``
  then runs each of its tests under the writer side of ``_test_rwlock``
  instead of the reader side.

  Usage::

    @thread_unsafe_test_class()
    class MyTest(absltest.TestCase):
      ...

  Raises:
    AssertionError: (when assertions are enabled) if the decorated class is
      not a ``unittest.TestCase`` subclass. ``issubclass`` itself raises
      ``TypeError`` if the decorated object is not a class at all.
  """
  def f(klass):
    # Report the offending class itself. ``type(klass)`` would only show the
    # metaclass (usually ``type``), which says nothing about what went wrong.
    assert issubclass(klass, unittest.TestCase), klass
    klass.thread_hostile = True
    return klass
  return f
class ThreadSafeTestResult:
  """Wraps a TestResult to make it thread safe.

  We do this by accumulating API calls and applying them in a batch under a
  lock at the conclusion of each test case.

  We duck type instead of inheriting from TestResult because we aren't actually
  a perfect implementation of TestResult, and would rather get a loud error
  for things we haven't implemented.

  Note, however, that unittest probes for a few optional methods and silently
  falls back when they are absent: an unexpected success is reported as a
  failure, and ``subTest`` blocks are not isolated. For that reason
  ``addUnexpectedSuccess`` and ``addSubTest`` are forwarded here explicitly.
  """

  def __init__(self, lock: threading.Lock, result: unittest.TestResult):
    """Creates a wrapper for a single test case.

    Args:
      lock: lock held while the recorded calls are replayed into ``result``.
        All wrappers around the same ``result`` are expected to share it.
      result: the underlying result object that calls are forwarded to.
    """
    self.lock = lock
    self.test_result = result
    # Deferred calls into ``test_result``; replayed in order by ``stopTest``.
    self.actions: list[Callable[[], None]] = []

  def startTest(self, test: unittest.TestCase):
    """Logs the start of ``test`` and records its wall-clock start time.

    The wrapped result's own ``startTest`` is deferred to ``stopTest``.
    """
    logger.info("Test start: %s", test.id())
    self.start_time = time.time()

  def stopTest(self, test: unittest.TestCase):
    """Logs the end of ``test`` and replays the recorded calls.

    While ``self.lock`` is held, the wrapped result receives ``startTest``,
    every recorded ``add*`` call in order, and then ``stopTest``.
    """
    logger.info("Test stop: %s", test.id())
    stop_time = time.time()
    with self.lock:
      # If test_result is an ABSL _TextAndXMLTestResult we override how it gets
      # the time. This affects the timing that shows up in the XML output
      # consumed by CI.
      unset = object()
      time_getter = getattr(self.test_result, "time_getter", unset)
      try:
        self.test_result.time_getter = lambda: self.start_time
        self.test_result.startTest(test)
        for callback in self.actions:
          callback()
        self.test_result.time_getter = lambda: stop_time
        self.test_result.stopTest(test)
      finally:
        if time_getter is unset:
          # The wrapped result had no time_getter of its own. Don't leave our
          # temporary override (pinned to this test's stop time) behind.
          try:
            del self.test_result.time_getter
          except AttributeError:
            pass
        else:
          self.test_result.time_getter = time_getter

  # Each add* method below only records the call; stopTest replays it.

  def addSuccess(self, test: unittest.TestCase):
    self.actions.append(lambda: self.test_result.addSuccess(test))

  def addSkip(self, test: unittest.TestCase, reason: str):
    self.actions.append(lambda: self.test_result.addSkip(test, reason))

  def addError(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addError(test, err))

  def addFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addFailure(test, err))

  def addExpectedFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))

  def addUnexpectedSuccess(self, test: unittest.TestCase):
    self.actions.append(lambda: self.test_result.addUnexpectedSuccess(test))

  def addSubTest(self, test: unittest.TestCase, subtest, err):
    # ``err`` is None when the subtest passed, else a sys.exc_info() tuple.
    self.actions.append(
        lambda: self.test_result.addSubTest(test, subtest, err))

  def addDuration(self, test: unittest.TestCase, elapsed):
    self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
class JaxTestSuite(unittest.TestSuite):
  """Runs tests in parallel using threads if TEST_NUM_THREADS is >= 1.

  With ``TEST_NUM_THREADS <= 0`` this delegates to ``unittest.TestSuite.run``.
  Otherwise every leaf test case is submitted to a thread pool of that size.
  Each one is run through ``_run_one_test`` with its own
  ``ThreadSafeTestResult`` wrapper around the shared ``result``.

  Caution: this test suite does not run setUpClass or setUpModule methods if
  thread parallelism is enabled.
  """

  def __init__(self, suite: unittest.TestSuite):
    # The direct children of ``suite`` become this suite's tests.
    super().__init__(list(suite))

  def run(
      self, result: unittest.TestResult, debug: bool = False
  ) -> unittest.TestResult:
    """Runs the contained tests, reporting into ``result``.

    Args:
      result: the result object to report into.
      debug: forwarded to ``unittest.TestSuite.run`` in the single-threaded
        path. ``TestSuite.debug()`` relies on it so that exceptions propagate
        instead of being collected. It is not honoured in the threaded path.

    Returns:
      ``result``.
    """
    if TEST_NUM_THREADS.value <= 0:
      return super().run(result, debug)

    test_warning_util.install_threadsafe_warning_handlers()
    executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
    # Shared by every ThreadSafeTestResult so replays into ``result`` never
    # interleave.
    lock = threading.Lock()
    futures = []

    def run_test(test):
      """Recursively runs tests in a test suite or test case."""
      if isinstance(test, unittest.TestSuite):
        for subtest in test:
          run_test(subtest)
      else:
        test_result = ThreadSafeTestResult(lock, result)
        futures.append(executor.submit(_run_one_test, test, test_result))

    with executor:
      run_test(self)
      # Future.result() re-raises any exception that escaped _run_one_test
      # in a worker thread.
      for future in futures:
        future.result()
    return result
class JaxTestLoader(absltest.TestLoader):
  """absltest loader that builds JaxTestSuites and filters test methods.

  Test method names are kept only if they match the _TEST_TARGETS regular
  expression (when set). They are then dropped if they match the
  _EXCLUDE_TEST_TARGETS regular expression (when set). Both patterns are
  applied with re.search to "<TestCaseClassName>.<method_name>".
  """

  suiteClass = JaxTestSuite

  def getTestCaseNames(self, testCaseClass):
    """Returns the test method names of ``testCaseClass`` that pass both filters."""
    selected = super().getTestCaseNames(testCaseClass)
    class_name = testCaseClass.__name__

    include = _TEST_TARGETS.value
    if include:
      include_re = re.compile(include)
      selected = [method for method in selected
                  if include_re.search(f"{class_name}.{method}")]

    exclude = _EXCLUDE_TEST_TARGETS.value
    if exclude:
      exclude_re = re.compile(exclude)
      selected = [method for method in selected
                  if not exclude_re.search(f"{class_name}.{method}")]

    return selected