mirror of
https://github.com/ROCm/jax.git
synced 2025-04-22 21:06:04 +00:00
223 lines
7.1 KiB
Python
223 lines
7.1 KiB
Python
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""
|
|
Contains a custom unittest loader and test suite.

Implements:
- A test filter based on the test_targets and exclude_test_targets flags, which
  default to the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment
  variables.
- A test suite that runs tests on a thread pool if JAX_TEST_NUM_THREADS
  is >= 1.
- Test decorators that mark a test case or test class as thread-hostile.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from contextlib import contextmanager
|
|
import logging
|
|
import os
|
|
import re
|
|
import threading
|
|
import time
|
|
import unittest
|
|
|
|
from absl.testing import absltest
|
|
from jax._src import config
|
|
from jax._src import test_warning_util
|
|
from jax._src import util
|
|
|
|
logger = logging.getLogger(__name__)


# Regex filters read by JaxTestLoader.getTestCaseNames. Each one is applied with
# re.search to "<TestCaseClass>.<test_method>"; an empty value disables it.
_TEST_TARGETS = config.string_flag(
  'test_targets',
  os.getenv('JAX_TEST_TARGETS', ''),
  'Regular expression specifying which tests to run, called via re.search on '
  'the test name. If empty or unspecified, run all tests.',
)

_EXCLUDE_TEST_TARGETS = config.string_flag(
  'exclude_test_targets',
  os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
  'Regular expression specifying which tests NOT to run, called via re.search '
  'on the test name. If empty or unspecified, run all tests.',
)
|
|
|
|
TEST_NUM_THREADS = config.int_flag(
  'jax_test_num_threads',
  # `or '0'` makes an exported-but-empty JAX_TEST_NUM_THREADS behave like an
  # unset one; int('') would otherwise raise ValueError at import time.
  int(os.getenv('JAX_TEST_NUM_THREADS') or '0'),
  help='Number of threads to use for running tests. 0 means run everything '
  'in the main thread. Using > 1 thread is experimental.'
)


# Reader-writer lock guarding test execution. Tests that may run in parallel
# hold the read (shared) side; tests that are not thread-safe hold the write
# (exclusive) side, so no other test runs while they hold it.
# NOTE(review): this file relies on util.Mutex exposing reader_lock /
# reader_unlock / writer_lock / writer_unlock / assert_reader_held.
_test_rwlock = util.Mutex()
|
|
|
|
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
  """Runs one test case while holding the matching side of `_test_rwlock`.

  Tests whose class sets a truthy ``thread_hostile`` attribute run under the
  exclusive (writer) lock; every other test runs under the shared (reader)
  lock, so those tests may overlap with one another.
  """
  if getattr(test.__class__, "thread_hostile", False):
    acquire, release = _test_rwlock.writer_lock, _test_rwlock.writer_unlock
  else:
    acquire, release = _test_rwlock.reader_lock, _test_rwlock.reader_unlock

  acquire()
  try:
    test(result)  # type: ignore
  finally:
    release()
|
|
|
|
|
|
@contextmanager
def thread_unsafe_test():
  """Context manager (also usable as a decorator) for non-thread-safe code.

  When threaded execution is enabled (TEST_NUM_THREADS > 0), the caller must
  already hold the shared (reader) side of the test lock. _run_one_test takes
  it for tests that are not thread-hostile. For the duration of the wrapped
  code, that shared hold is swapped for an exclusive (writer) hold, and the
  shared hold is restored afterwards. With TEST_NUM_THREADS <= 0 this is a
  no-op.

  Note: this only covers the code it wraps, not code in separate setUp() or
  tearDown() methods.
  """
  if TEST_NUM_THREADS.value > 0:
    _test_rwlock.assert_reader_held()
    # Drop the shared hold first; taking the writer side while still holding
    # a reader side on the same lock could never succeed.
    _test_rwlock.reader_unlock()
    _test_rwlock.writer_lock()
    try:
      yield
    finally:
      _test_rwlock.writer_unlock()
      _test_rwlock.reader_lock()
  else:
    yield
|
|
|
|
|
|
def thread_unsafe_test_class():
  """Returns a class decorator that marks a TestCase class as thread-hostile.

  The decorator sets ``thread_hostile = True`` on the class. _run_one_test
  checks that attribute and runs each of the class's tests under the exclusive
  (writer) side of the test lock.

  Raises:
    TypeError: (from the returned decorator) if it is applied to anything other
      than a ``unittest.TestCase`` subclass.
  """
  def f(klass):
    # Raise explicitly rather than `assert`, which is stripped under `python -O`,
    # and name the offending object itself (`type(klass)` would just be `type`).
    if not (isinstance(klass, type) and issubclass(klass, unittest.TestCase)):
      raise TypeError(
          "thread_unsafe_test_class() expects a unittest.TestCase subclass, "
          f"got {klass!r}")
    klass.thread_hostile = True
    return klass
  return f
|
|
|
|
|
|
class ThreadSafeTestResult:
  """Wraps a TestResult to make it thread safe.

  While a test runs on a worker thread, every outcome call made on this object
  is recorded as a deferred action. ``stopTest`` replays the wrapped result's
  ``startTest``, then those actions in order, then its ``stopTest``, all while
  holding ``lock``. The shared result is therefore only touched under the lock.

  We duck type instead of inheriting from TestResult because we aren't actually
  a perfect implementation of TestResult, and would rather get a loud error
  for things we haven't implemented.

  NOTE(review): no ``addSubTest`` is defined, so unittest treats this result as
  not supporting subtests (it checks ``hasattr(result, "addSubTest")``), and
  ``self.subTest(...)`` blocks are not reported individually when tests run on
  threads. Confirm this is intended.
  """

  def __init__(self, lock: threading.Lock, result: unittest.TestResult):
    # Serializes replay into `result`; JaxTestSuite.run passes the same lock to
    # every wrapper it creates around one result.
    self.lock = lock
    self.test_result = result
    # Deferred calls against `test_result`, replayed in order by stopTest().
    self.actions: list[Callable[[], None]] = []

  def startTest(self, test: unittest.TestCase):
    """Logs the start and records the wall-clock start time.

    The wrapped result's own startTest() is deferred until stopTest().
    """
    logger.info("Test start: %s", test.id())
    self.start_time = time.time()

  def stopTest(self, test: unittest.TestCase):
    """Replays startTest, all recorded outcomes, and stopTest under the lock."""
    logger.info("Test stop: %s", test.id())
    stop_time = time.time()
    with self.lock:
      # If test_result is an ABSL _TextAndXMLTestResult we override how it gets
      # the time. This affects the timing that shows up in the XML output
      # consumed by CI.
      time_getter = getattr(self.test_result, "time_getter", None)
      try:
        self.test_result.time_getter = lambda: self.start_time
        self.test_result.startTest(test)
        for callback in self.actions:
          callback()
        self.test_result.time_getter = lambda: stop_time
        self.test_result.stopTest(test)
      finally:
        if time_getter is not None:
          self.test_result.time_getter = time_getter

  def addSuccess(self, test: unittest.TestCase):
    self.actions.append(lambda: self.test_result.addSuccess(test))

  def addSkip(self, test: unittest.TestCase, reason: str):
    self.actions.append(lambda: self.test_result.addSkip(test, reason))

  def addError(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addError(test, err))

  def addFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addFailure(test, err))

  def addExpectedFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))

  def addUnexpectedSuccess(self, test: unittest.TestCase):
    # Without this method, unittest falls back to reporting an unexpectedly
    # passing @expectedFailure test as a failure (with a RuntimeWarning), which
    # would differ from the serial path's report.
    self.actions.append(lambda: self.test_result.addUnexpectedSuccess(test))

  def addDuration(self, test: unittest.TestCase, elapsed):
    self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
|
|
|
|
|
|
class JaxTestSuite(unittest.TestSuite):
  """Runs tests on a thread pool if TEST_NUM_THREADS is >= 1.

  With TEST_NUM_THREADS <= 0, or in debug mode (``TestSuite.debug()``), tests
  run serially via ``unittest.TestSuite.run``.

  Caution: this test suite does not run setUpClass or setUpModule methods (nor
  the matching tearDownClass / tearDownModule) if thread parallelism is enabled.
  """

  def __init__(self, suite: unittest.TestSuite):
    super().__init__(list(suite))

  def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
    """Runs every test in this suite, serially or on a thread pool.

    Args:
      result: the TestResult that collects outcomes.
      debug: forwarded to ``unittest.TestSuite.run``. When true, the suite
        always runs serially: debug mode calls each test's ``debug()`` so that
        exceptions propagate instead of being recorded in ``result``.

    Returns:
      ``result``.
    """
    if debug or TEST_NUM_THREADS.value <= 0:
      return super().run(result, debug)

    test_warning_util.install_threadsafe_warning_handlers()

    executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
    # Shared by all per-test wrappers so replay into `result` is serialized.
    lock = threading.Lock()
    futures = []

    def run_test(test):
      """Recursively submits each TestCase of a (possibly nested) suite."""
      if isinstance(test, unittest.TestSuite):
        for subtest in test:
          run_test(subtest)
      else:
        test_result = ThreadSafeTestResult(lock, result)
        futures.append(executor.submit(_run_one_test, test, test_result))

    with executor:
      run_test(self)
      # Re-raise anything that escaped a worker. Ordinary test failures are
      # recorded in `result` by TestCase.run, not raised.
      for future in futures:
        future.result()

    return result
|
|
|
|
|
|
class JaxTestLoader(absltest.TestLoader):
  """absltest loader that builds JaxTestSuites and filters test methods.

  Each candidate method is matched as the string "<TestCaseClass>.<method>"
  with re.search. A method is kept only if it matches the test_targets pattern
  (when set) and does not match the exclude_test_targets pattern (when set).
  """
  suiteClass = JaxTestSuite

  def getTestCaseNames(self, testCaseClass):
    candidates = super().getTestCaseNames(testCaseClass)
    include_src = _TEST_TARGETS.value
    exclude_src = _EXCLUDE_TEST_TARGETS.value
    if not include_src and not exclude_src:
      return candidates

    include_re = re.compile(include_src) if include_src else None
    exclude_re = re.compile(exclude_src) if exclude_src else None
    prefix = testCaseClass.__name__

    selected = []
    for method_name in candidates:
      qualified = f"{prefix}.{method_name}"
      if include_re is not None and not include_re.search(qualified):
        continue
      if exclude_re is not None and exclude_re.search(qualified):
        continue
      selected.append(method_name)
    return selected
|