mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Split JaxTestLoader and related classes into a separate file.
Refactoring only, no functional changes intended. PiperOrigin-RevId: 745813442
This commit is contained in:
parent
cf268a7f6a
commit
382285d315
@ -142,6 +142,7 @@ py_library(
|
|||||||
# these are available in jax.test_util via the standard :jax target.
|
# these are available in jax.test_util via the standard :jax target.
|
||||||
name = "test_util",
|
name = "test_util",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"_src/test_loader.py",
|
||||||
"_src/test_util.py",
|
"_src/test_util.py",
|
||||||
"_src/test_warning_util.py",
|
"_src/test_warning_util.py",
|
||||||
],
|
],
|
||||||
|
218
jax/_src/test_loader.py
Normal file
218
jax/_src/test_loader.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
# Copyright 2018 The JAX Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Contains a custom unittest loader and test suite.
|
||||||
|
|
||||||
|
Implements:
|
||||||
|
- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS
|
||||||
|
environment variables.
|
||||||
|
- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS
|
||||||
|
is >= 1.
|
||||||
|
- Test decorators that mark a test case or test class as thread-hostile.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from jax._src import config
|
||||||
|
from jax._src import test_warning_util
|
||||||
|
from jax._src import util
|
||||||
|
|
||||||
|
|
||||||
|
_TEST_TARGETS = config.string_flag(
|
||||||
|
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
|
||||||
|
'Regular expression specifying which tests to run, called via re.search on '
|
||||||
|
'the test name. If empty or unspecified, run all tests.'
|
||||||
|
)
|
||||||
|
|
||||||
|
_EXCLUDE_TEST_TARGETS = config.string_flag(
|
||||||
|
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
|
||||||
|
'Regular expression specifying which tests NOT to run, called via re.search '
|
||||||
|
'on the test name. If empty or unspecified, run all tests.'
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_NUM_THREADS = config.int_flag(
|
||||||
|
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
|
||||||
|
help='Number of threads to use for running tests. 0 means run everything '
|
||||||
|
'in the main thread. Using > 1 thread is experimental.'
|
||||||
|
)
|
||||||
|
|
||||||
|
# We use a reader-writer lock to protect test execution. Tests that may run in
|
||||||
|
# parallel acquire a read lock; tests that are not thread-safe acquire a write
|
||||||
|
# lock.
|
||||||
|
_test_rwlock = util.Mutex()
|
||||||
|
|
||||||
|
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
|
||||||
|
if getattr(test.__class__, "thread_hostile", False):
|
||||||
|
_test_rwlock.writer_lock()
|
||||||
|
try:
|
||||||
|
test(result) # type: ignore
|
||||||
|
finally:
|
||||||
|
_test_rwlock.writer_unlock()
|
||||||
|
else:
|
||||||
|
_test_rwlock.reader_lock()
|
||||||
|
try:
|
||||||
|
test(result) # type: ignore
|
||||||
|
finally:
|
||||||
|
_test_rwlock.reader_unlock()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def thread_unsafe_test():
|
||||||
|
"""Decorator for tests that are not thread-safe.
|
||||||
|
|
||||||
|
Note: this decorator (naturally) only applies to what it wraps, not to, say,
|
||||||
|
code in separate setUp() or tearDown() methods.
|
||||||
|
"""
|
||||||
|
if TEST_NUM_THREADS.value <= 0:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
_test_rwlock.assert_reader_held()
|
||||||
|
_test_rwlock.reader_unlock()
|
||||||
|
_test_rwlock.writer_lock()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_test_rwlock.writer_unlock()
|
||||||
|
_test_rwlock.reader_lock()
|
||||||
|
|
||||||
|
|
||||||
|
def thread_unsafe_test_class():
|
||||||
|
"""Decorator that marks a TestCase class as thread-hostile."""
|
||||||
|
def f(klass):
|
||||||
|
assert issubclass(klass, unittest.TestCase), type(klass)
|
||||||
|
klass.thread_hostile = True
|
||||||
|
return klass
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSafeTestResult:
|
||||||
|
"""
|
||||||
|
Wraps a TestResult to make it thread safe.
|
||||||
|
|
||||||
|
We do this by accumulating API calls and applying them in a batch under a
|
||||||
|
lock at the conclusion of each test case.
|
||||||
|
|
||||||
|
We duck type instead of inheriting from TestResult because we aren't actually
|
||||||
|
a perfect implementation of TestResult, and would rather get a loud error
|
||||||
|
for things we haven't implemented.
|
||||||
|
"""
|
||||||
|
def __init__(self, lock: threading.Lock, result: unittest.TestResult):
|
||||||
|
self.lock = lock
|
||||||
|
self.test_result = result
|
||||||
|
self.actions: list[Callable[[], None]] = []
|
||||||
|
|
||||||
|
def startTest(self, test: unittest.TestCase):
|
||||||
|
del test
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
def stopTest(self, test: unittest.TestCase):
|
||||||
|
stop_time = time.time()
|
||||||
|
with self.lock:
|
||||||
|
# If test_result is an ABSL _TextAndXMLTestResult we override how it gets
|
||||||
|
# the time. This affects the timing that shows up in the XML output
|
||||||
|
# consumed by CI.
|
||||||
|
time_getter = getattr(self.test_result, "time_getter", None)
|
||||||
|
try:
|
||||||
|
self.test_result.time_getter = lambda: self.start_time
|
||||||
|
self.test_result.startTest(test)
|
||||||
|
for callback in self.actions:
|
||||||
|
callback()
|
||||||
|
self.test_result.time_getter = lambda: stop_time
|
||||||
|
self.test_result.stopTest(test)
|
||||||
|
finally:
|
||||||
|
if time_getter is not None:
|
||||||
|
self.test_result.time_getter = time_getter
|
||||||
|
|
||||||
|
def addSuccess(self, test: unittest.TestCase):
|
||||||
|
self.actions.append(lambda: self.test_result.addSuccess(test))
|
||||||
|
|
||||||
|
def addSkip(self, test: unittest.TestCase, reason: str):
|
||||||
|
self.actions.append(lambda: self.test_result.addSkip(test, reason))
|
||||||
|
|
||||||
|
def addError(self, test: unittest.TestCase, err):
|
||||||
|
self.actions.append(lambda: self.test_result.addError(test, err))
|
||||||
|
|
||||||
|
def addFailure(self, test: unittest.TestCase, err):
|
||||||
|
self.actions.append(lambda: self.test_result.addFailure(test, err))
|
||||||
|
|
||||||
|
def addExpectedFailure(self, test: unittest.TestCase, err):
|
||||||
|
self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))
|
||||||
|
|
||||||
|
def addDuration(self, test: unittest.TestCase, elapsed):
|
||||||
|
self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
|
||||||
|
|
||||||
|
|
||||||
|
class JaxTestSuite(unittest.TestSuite):
|
||||||
|
"""Runs tests in parallel using threads if TEST_NUM_THREADS is > 1.
|
||||||
|
|
||||||
|
Caution: this test suite does not run setUpClass or setUpModule methods if
|
||||||
|
thread parallelism is enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, suite: unittest.TestSuite):
|
||||||
|
super().__init__(list(suite))
|
||||||
|
|
||||||
|
def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
|
||||||
|
if TEST_NUM_THREADS.value <= 0:
|
||||||
|
return super().run(result)
|
||||||
|
|
||||||
|
test_warning_util.install_threadsafe_warning_handlers()
|
||||||
|
|
||||||
|
executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
|
||||||
|
lock = threading.Lock()
|
||||||
|
futures = []
|
||||||
|
|
||||||
|
def run_test(test):
|
||||||
|
"""Recursively runs tests in a test suite or test case."""
|
||||||
|
if isinstance(test, unittest.TestSuite):
|
||||||
|
for subtest in test:
|
||||||
|
run_test(subtest)
|
||||||
|
else:
|
||||||
|
test_result = ThreadSafeTestResult(lock, result)
|
||||||
|
futures.append(executor.submit(_run_one_test, test, test_result))
|
||||||
|
|
||||||
|
with executor:
|
||||||
|
run_test(self)
|
||||||
|
for future in futures:
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class JaxTestLoader(absltest.TestLoader):
|
||||||
|
suiteClass = JaxTestSuite
|
||||||
|
|
||||||
|
def getTestCaseNames(self, testCaseClass):
|
||||||
|
names = super().getTestCaseNames(testCaseClass)
|
||||||
|
if _TEST_TARGETS.value:
|
||||||
|
pattern = re.compile(_TEST_TARGETS.value)
|
||||||
|
names = [name for name in names
|
||||||
|
if pattern.search(f"{testCaseClass.__name__}.{name}")]
|
||||||
|
if _EXCLUDE_TEST_TARGETS.value:
|
||||||
|
pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
|
||||||
|
names = [name for name in names
|
||||||
|
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
|
||||||
|
return names
|
@ -17,7 +17,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
from collections.abc import Callable, Generator, Iterable, Sequence
|
from collections.abc import Callable, Generator, Iterable, Sequence
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
@ -32,12 +31,10 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import textwrap
|
import textwrap
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from typing import Any, TextIO
|
from typing import Any, TextIO
|
||||||
import unittest
|
import unittest
|
||||||
import zlib
|
import zlib
|
||||||
|
|
||||||
from absl.testing import absltest
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import jax
|
import jax
|
||||||
from jax import lax
|
from jax import lax
|
||||||
@ -63,12 +60,17 @@ from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
|
|||||||
from jax._src.public_test_util import ( # noqa: F401
|
from jax._src.public_test_util import ( # noqa: F401
|
||||||
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
|
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
|
||||||
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict)
|
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict)
|
||||||
|
from jax._src.test_loader import thread_unsafe_test as thread_unsafe_test
|
||||||
|
from jax._src.test_loader import thread_unsafe_test_class as thread_unsafe_test_class
|
||||||
|
from jax._src.test_loader import JaxTestLoader as JaxTestLoader
|
||||||
|
from jax._src.test_loader import TEST_NUM_THREADS as TEST_NUM_THREADS
|
||||||
from jax._src.util import unzip2
|
from jax._src.util import unzip2
|
||||||
from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten
|
from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.random as npr
|
import numpy.random as npr
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# This submodule includes private test utilities that are not exported to
|
# This submodule includes private test utilities that are not exported to
|
||||||
# jax.test_util. Functionality appearing here is for internal use only, and
|
# jax.test_util. Functionality appearing here is for internal use only, and
|
||||||
# may be changed or removed at any time and without any deprecation cycle.
|
# may be changed or removed at any time and without any deprecation cycle.
|
||||||
@ -98,16 +100,6 @@ SKIP_SLOW_TESTS = config.bool_flag(
|
|||||||
help='Skip tests marked as slow (> 5 sec).'
|
help='Skip tests marked as slow (> 5 sec).'
|
||||||
)
|
)
|
||||||
|
|
||||||
_TEST_TARGETS = config.string_flag(
|
|
||||||
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
|
|
||||||
'Regular expression specifying which tests to run, called via re.search on '
|
|
||||||
'the test name. If empty or unspecified, run all tests.'
|
|
||||||
)
|
|
||||||
_EXCLUDE_TEST_TARGETS = config.string_flag(
|
|
||||||
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
|
|
||||||
'Regular expression specifying which tests NOT to run, called via re.search '
|
|
||||||
'on the test name. If empty or unspecified, run all tests.'
|
|
||||||
)
|
|
||||||
TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag(
|
TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag(
|
||||||
'jax_test_with_persistent_compilation_cache',
|
'jax_test_with_persistent_compilation_cache',
|
||||||
config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
|
config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
|
||||||
@ -121,11 +113,6 @@ HYPOTHESIS_PROFILE = config.string_flag(
|
|||||||
'deterministic, interactive'),
|
'deterministic, interactive'),
|
||||||
)
|
)
|
||||||
|
|
||||||
TEST_NUM_THREADS = config.int_flag(
|
|
||||||
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
|
|
||||||
help='Number of threads to use for running tests. 0 means run everything '
|
|
||||||
'in the main thread. Using > 1 thread is experimental.'
|
|
||||||
)
|
|
||||||
|
|
||||||
# We sanitize test names to ensure they work with "unitttest -k" and
|
# We sanitize test names to ensure they work with "unitttest -k" and
|
||||||
# "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
|
# "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
|
||||||
@ -1074,165 +1061,6 @@ def sample_product(*args, **kw):
|
|||||||
"""
|
"""
|
||||||
return parameterized.parameters(*sample_product_testcases(*args, **kw))
|
return parameterized.parameters(*sample_product_testcases(*args, **kw))
|
||||||
|
|
||||||
# We use a reader-writer lock to protect test execution. Tests that may run in
|
|
||||||
# parallel acquire a read lock; tests that are not thread-safe acquire a write
|
|
||||||
# lock.
|
|
||||||
_test_rwlock = util.Mutex()
|
|
||||||
|
|
||||||
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
|
|
||||||
if getattr(test.__class__, "thread_hostile", False):
|
|
||||||
_test_rwlock.writer_lock()
|
|
||||||
try:
|
|
||||||
test(result) # type: ignore
|
|
||||||
finally:
|
|
||||||
_test_rwlock.writer_unlock()
|
|
||||||
else:
|
|
||||||
_test_rwlock.reader_lock()
|
|
||||||
try:
|
|
||||||
test(result) # type: ignore
|
|
||||||
finally:
|
|
||||||
_test_rwlock.reader_unlock()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def thread_unsafe_test():
|
|
||||||
"""Decorator for tests that are not thread-safe.
|
|
||||||
|
|
||||||
Note: this decorator (naturally) only applies to what it wraps, not to, say,
|
|
||||||
code in separate setUp() or tearDown() methods.
|
|
||||||
"""
|
|
||||||
if TEST_NUM_THREADS.value <= 0:
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
_test_rwlock.assert_reader_held()
|
|
||||||
_test_rwlock.reader_unlock()
|
|
||||||
_test_rwlock.writer_lock()
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
_test_rwlock.writer_unlock()
|
|
||||||
_test_rwlock.reader_lock()
|
|
||||||
|
|
||||||
|
|
||||||
def thread_unsafe_test_class():
|
|
||||||
"Decorator that marks a TestCase class as thread-hostile."
|
|
||||||
def f(klass):
|
|
||||||
assert issubclass(klass, unittest.TestCase), type(klass)
|
|
||||||
klass.thread_hostile = True
|
|
||||||
return klass
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadSafeTestResult:
|
|
||||||
"""
|
|
||||||
Wraps a TestResult to make it thread safe.
|
|
||||||
|
|
||||||
We do this by accumulating API calls and applying them in a batch under a
|
|
||||||
lock at the conclusion of each test case.
|
|
||||||
|
|
||||||
We duck type instead of inheriting from TestResult because we aren't actually
|
|
||||||
a perfect implementation of TestResult, and would rather get a loud error
|
|
||||||
for things we haven't implemented.
|
|
||||||
"""
|
|
||||||
def __init__(self, lock: threading.Lock, result: unittest.TestResult):
|
|
||||||
self.lock = lock
|
|
||||||
self.test_result = result
|
|
||||||
self.actions: list[Callable] = []
|
|
||||||
|
|
||||||
def startTest(self, test: unittest.TestCase):
|
|
||||||
del test
|
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
def stopTest(self, test: unittest.TestCase):
|
|
||||||
stop_time = time.time()
|
|
||||||
with self.lock:
|
|
||||||
# If test_result is an ABSL _TextAndXMLTestResult we override how it gets
|
|
||||||
# the time. This affects the timing that shows up in the XML output
|
|
||||||
# consumed by CI.
|
|
||||||
time_getter = getattr(self.test_result, "time_getter", None)
|
|
||||||
try:
|
|
||||||
self.test_result.time_getter = lambda: self.start_time
|
|
||||||
self.test_result.startTest(test)
|
|
||||||
for callback in self.actions:
|
|
||||||
callback()
|
|
||||||
self.test_result.time_getter = lambda: stop_time
|
|
||||||
self.test_result.stopTest(test)
|
|
||||||
finally:
|
|
||||||
if time_getter is not None:
|
|
||||||
self.test_result.time_getter = time_getter
|
|
||||||
|
|
||||||
def addSuccess(self, test: unittest.TestCase):
|
|
||||||
self.actions.append(lambda: self.test_result.addSuccess(test))
|
|
||||||
|
|
||||||
def addSkip(self, test: unittest.TestCase, reason: str):
|
|
||||||
self.actions.append(lambda: self.test_result.addSkip(test, reason))
|
|
||||||
|
|
||||||
def addError(self, test: unittest.TestCase, err):
|
|
||||||
self.actions.append(lambda: self.test_result.addError(test, err))
|
|
||||||
|
|
||||||
def addFailure(self, test: unittest.TestCase, err):
|
|
||||||
self.actions.append(lambda: self.test_result.addFailure(test, err))
|
|
||||||
|
|
||||||
def addExpectedFailure(self, test: unittest.TestCase, err):
|
|
||||||
self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))
|
|
||||||
|
|
||||||
def addDuration(self, test: unittest.TestCase, elapsed):
|
|
||||||
self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
|
|
||||||
|
|
||||||
|
|
||||||
class JaxTestSuite(unittest.TestSuite):
|
|
||||||
"""Runs tests in parallel using threads if TEST_NUM_THREADS is > 1.
|
|
||||||
|
|
||||||
Caution: this test suite does not run setUpClass or setUpModule methods if
|
|
||||||
thread parallelism is enabled.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, suite: unittest.TestSuite):
|
|
||||||
super().__init__(list(suite))
|
|
||||||
|
|
||||||
def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
|
|
||||||
if TEST_NUM_THREADS.value <= 0:
|
|
||||||
return super().run(result)
|
|
||||||
|
|
||||||
test_warning_util.install_threadsafe_warning_handlers()
|
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
|
|
||||||
lock = threading.Lock()
|
|
||||||
futures = []
|
|
||||||
|
|
||||||
def run_test(test):
|
|
||||||
"Recursively runs tests in a test suite or test case."
|
|
||||||
if isinstance(test, unittest.TestSuite):
|
|
||||||
for subtest in test:
|
|
||||||
run_test(subtest)
|
|
||||||
else:
|
|
||||||
test_result = ThreadSafeTestResult(lock, result)
|
|
||||||
futures.append(executor.submit(_run_one_test, test, test_result))
|
|
||||||
|
|
||||||
with executor:
|
|
||||||
run_test(self)
|
|
||||||
for future in futures:
|
|
||||||
future.result()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class JaxTestLoader(absltest.TestLoader):
|
|
||||||
suiteClass = JaxTestSuite
|
|
||||||
|
|
||||||
def getTestCaseNames(self, testCaseClass):
|
|
||||||
names = super().getTestCaseNames(testCaseClass)
|
|
||||||
if _TEST_TARGETS.value:
|
|
||||||
pattern = re.compile(_TEST_TARGETS.value)
|
|
||||||
names = [name for name in names
|
|
||||||
if pattern.search(f"{testCaseClass.__name__}.{name}")]
|
|
||||||
if _EXCLUDE_TEST_TARGETS.value:
|
|
||||||
pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
|
|
||||||
names = [name for name in names
|
|
||||||
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
|
|
||||||
return names
|
|
||||||
|
|
||||||
|
|
||||||
def with_config(**kwds):
|
def with_config(**kwds):
|
||||||
"""Test case decorator for subclasses of JaxTestCase"""
|
"""Test case decorator for subclasses of JaxTestCase"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user