Split JaxTestLoader and related classes into a separate file.

Refactoring only, no functional changes intended.

PiperOrigin-RevId: 745813442
This commit is contained in:
Peter Hawkins 2025-04-09 18:44:41 -07:00 committed by jax authors
parent cf268a7f6a
commit 382285d315
3 changed files with 224 additions and 177 deletions

View File

@ -142,6 +142,7 @@ py_library(
# these are available in jax.test_util via the standard :jax target.
name = "test_util",
srcs = [
"_src/test_loader.py",
"_src/test_util.py",
"_src/test_warning_util.py",
],

218
jax/_src/test_loader.py Normal file
View File

@ -0,0 +1,218 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a custom unittest loader and test suite.
Implements:
- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS
environment variables.
- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS
is >= 1.
- Test decorators that mark a test case or test class as thread-hostile.
"""
from __future__ import annotations
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
import os
import re
import threading
import time
import unittest
from absl.testing import absltest
from jax._src import config
from jax._src import test_warning_util
from jax._src import util
_TEST_TARGETS = config.string_flag(
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
'Regular expression specifying which tests to run, called via re.search on '
'the test name. If empty or unspecified, run all tests.'
)
_EXCLUDE_TEST_TARGETS = config.string_flag(
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
'Regular expression specifying which tests NOT to run, called via re.search '
'on the test name. If empty or unspecified, run all tests.'
)
TEST_NUM_THREADS = config.int_flag(
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
help='Number of threads to use for running tests. 0 means run everything '
'in the main thread. Using > 1 thread is experimental.'
)
# We use a reader-writer lock to protect test execution. Tests that may run in
# parallel acquire a read lock; tests that are not thread-safe acquire a write
# lock.
_test_rwlock = util.Mutex()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
if getattr(test.__class__, "thread_hostile", False):
_test_rwlock.writer_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.writer_unlock()
else:
_test_rwlock.reader_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.reader_unlock()
@contextmanager
def thread_unsafe_test():
"""Decorator for tests that are not thread-safe.
Note: this decorator (naturally) only applies to what it wraps, not to, say,
code in separate setUp() or tearDown() methods.
"""
if TEST_NUM_THREADS.value <= 0:
yield
return
_test_rwlock.assert_reader_held()
_test_rwlock.reader_unlock()
_test_rwlock.writer_lock()
try:
yield
finally:
_test_rwlock.writer_unlock()
_test_rwlock.reader_lock()
def thread_unsafe_test_class():
"""Decorator that marks a TestCase class as thread-hostile."""
def f(klass):
assert issubclass(klass, unittest.TestCase), type(klass)
klass.thread_hostile = True
return klass
return f
class ThreadSafeTestResult:
"""
Wraps a TestResult to make it thread safe.
We do this by accumulating API calls and applying them in a batch under a
lock at the conclusion of each test case.
We duck type instead of inheriting from TestResult because we aren't actually
a perfect implementation of TestResult, and would rather get a loud error
for things we haven't implemented.
"""
def __init__(self, lock: threading.Lock, result: unittest.TestResult):
self.lock = lock
self.test_result = result
self.actions: list[Callable[[], None]] = []
def startTest(self, test: unittest.TestCase):
del test
self.start_time = time.time()
def stopTest(self, test: unittest.TestCase):
stop_time = time.time()
with self.lock:
# If test_result is an ABSL _TextAndXMLTestResult we override how it gets
# the time. This affects the timing that shows up in the XML output
# consumed by CI.
time_getter = getattr(self.test_result, "time_getter", None)
try:
self.test_result.time_getter = lambda: self.start_time
self.test_result.startTest(test)
for callback in self.actions:
callback()
self.test_result.time_getter = lambda: stop_time
self.test_result.stopTest(test)
finally:
if time_getter is not None:
self.test_result.time_getter = time_getter
def addSuccess(self, test: unittest.TestCase):
self.actions.append(lambda: self.test_result.addSuccess(test))
def addSkip(self, test: unittest.TestCase, reason: str):
self.actions.append(lambda: self.test_result.addSkip(test, reason))
def addError(self, test: unittest.TestCase, err):
self.actions.append(lambda: self.test_result.addError(test, err))
def addFailure(self, test: unittest.TestCase, err):
self.actions.append(lambda: self.test_result.addFailure(test, err))
def addExpectedFailure(self, test: unittest.TestCase, err):
self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))
def addDuration(self, test: unittest.TestCase, elapsed):
self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
class JaxTestSuite(unittest.TestSuite):
"""Runs tests in parallel using threads if TEST_NUM_THREADS is > 1.
Caution: this test suite does not run setUpClass or setUpModule methods if
thread parallelism is enabled.
"""
def __init__(self, suite: unittest.TestSuite):
super().__init__(list(suite))
def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
if TEST_NUM_THREADS.value <= 0:
return super().run(result)
test_warning_util.install_threadsafe_warning_handlers()
executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
lock = threading.Lock()
futures = []
def run_test(test):
"""Recursively runs tests in a test suite or test case."""
if isinstance(test, unittest.TestSuite):
for subtest in test:
run_test(subtest)
else:
test_result = ThreadSafeTestResult(lock, result)
futures.append(executor.submit(_run_one_test, test, test_result))
with executor:
run_test(self)
for future in futures:
future.result()
return result
class JaxTestLoader(absltest.TestLoader):
suiteClass = JaxTestSuite
def getTestCaseNames(self, testCaseClass):
names = super().getTestCaseNames(testCaseClass)
if _TEST_TARGETS.value:
pattern = re.compile(_TEST_TARGETS.value)
names = [name for name in names
if pattern.search(f"{testCaseClass.__name__}.{name}")]
if _EXCLUDE_TEST_TARGETS.value:
pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
names = [name for name in names
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
return names

View File

@ -17,7 +17,6 @@ from __future__ import annotations
import collections
from collections.abc import Callable, Generator, Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
import datetime
import functools
@ -32,12 +31,10 @@ import sys
import tempfile
import textwrap
import threading
import time
from typing import Any, TextIO
import unittest
import zlib
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
@ -63,12 +60,17 @@ from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict)
from jax._src.test_loader import thread_unsafe_test as thread_unsafe_test
from jax._src.test_loader import thread_unsafe_test_class as thread_unsafe_test_class
from jax._src.test_loader import JaxTestLoader as JaxTestLoader
from jax._src.test_loader import TEST_NUM_THREADS as TEST_NUM_THREADS
from jax._src.util import unzip2
from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten
import numpy as np
import numpy.random as npr
# This submodule includes private test utilities that are not exported to
# jax.test_util. Functionality appearing here is for internal use only, and
# may be changed or removed at any time and without any deprecation cycle.
@ -98,16 +100,6 @@ SKIP_SLOW_TESTS = config.bool_flag(
help='Skip tests marked as slow (> 5 sec).'
)
_TEST_TARGETS = config.string_flag(
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
'Regular expression specifying which tests to run, called via re.search on '
'the test name. If empty or unspecified, run all tests.'
)
_EXCLUDE_TEST_TARGETS = config.string_flag(
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
'Regular expression specifying which tests NOT to run, called via re.search '
'on the test name. If empty or unspecified, run all tests.'
)
TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag(
'jax_test_with_persistent_compilation_cache',
config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
@ -121,11 +113,6 @@ HYPOTHESIS_PROFILE = config.string_flag(
'deterministic, interactive'),
)
TEST_NUM_THREADS = config.int_flag(
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
help='Number of threads to use for running tests. 0 means run everything '
'in the main thread. Using > 1 thread is experimental.'
)
# We sanitize test names to ensure they work with "unitttest -k" and
# "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
@ -1074,165 +1061,6 @@ def sample_product(*args, **kw):
"""
return parameterized.parameters(*sample_product_testcases(*args, **kw))
# We use a reader-writer lock to protect test execution. Tests that may run in
# parallel acquire a read lock; tests that are not thread-safe acquire a write
# lock.
_test_rwlock = util.Mutex()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
if getattr(test.__class__, "thread_hostile", False):
_test_rwlock.writer_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.writer_unlock()
else:
_test_rwlock.reader_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.reader_unlock()
@contextmanager
def thread_unsafe_test():
"""Decorator for tests that are not thread-safe.
Note: this decorator (naturally) only applies to what it wraps, not to, say,
code in separate setUp() or tearDown() methods.
"""
if TEST_NUM_THREADS.value <= 0:
yield
return
_test_rwlock.assert_reader_held()
_test_rwlock.reader_unlock()
_test_rwlock.writer_lock()
try:
yield
finally:
_test_rwlock.writer_unlock()
_test_rwlock.reader_lock()
def thread_unsafe_test_class():
"Decorator that marks a TestCase class as thread-hostile."
def f(klass):
assert issubclass(klass, unittest.TestCase), type(klass)
klass.thread_hostile = True
return klass
return f
class ThreadSafeTestResult:
"""
Wraps a TestResult to make it thread safe.
We do this by accumulating API calls and applying them in a batch under a
lock at the conclusion of each test case.
We duck type instead of inheriting from TestResult because we aren't actually
a perfect implementation of TestResult, and would rather get a loud error
for things we haven't implemented.
"""
def __init__(self, lock: threading.Lock, result: unittest.TestResult):
self.lock = lock
self.test_result = result
self.actions: list[Callable] = []
def startTest(self, test: unittest.TestCase):
del test
self.start_time = time.time()
def stopTest(self, test: unittest.TestCase):
stop_time = time.time()
with self.lock:
# If test_result is an ABSL _TextAndXMLTestResult we override how it gets
# the time. This affects the timing that shows up in the XML output
# consumed by CI.
time_getter = getattr(self.test_result, "time_getter", None)
try:
self.test_result.time_getter = lambda: self.start_time
self.test_result.startTest(test)
for callback in self.actions:
callback()
self.test_result.time_getter = lambda: stop_time
self.test_result.stopTest(test)
finally:
if time_getter is not None:
self.test_result.time_getter = time_getter
def addSuccess(self, test: unittest.TestCase):
self.actions.append(lambda: self.test_result.addSuccess(test))
def addSkip(self, test: unittest.TestCase, reason: str):
self.actions.append(lambda: self.test_result.addSkip(test, reason))
def addError(self, test: unittest.TestCase, err):
self.actions.append(lambda: self.test_result.addError(test, err))
def addFailure(self, test: unittest.TestCase, err):
self.actions.append(lambda: self.test_result.addFailure(test, err))
def addExpectedFailure(self, test: unittest.TestCase, err):
self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))
def addDuration(self, test: unittest.TestCase, elapsed):
self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
class JaxTestSuite(unittest.TestSuite):
"""Runs tests in parallel using threads if TEST_NUM_THREADS is > 1.
Caution: this test suite does not run setUpClass or setUpModule methods if
thread parallelism is enabled.
"""
def __init__(self, suite: unittest.TestSuite):
super().__init__(list(suite))
def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
if TEST_NUM_THREADS.value <= 0:
return super().run(result)
test_warning_util.install_threadsafe_warning_handlers()
executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
lock = threading.Lock()
futures = []
def run_test(test):
"Recursively runs tests in a test suite or test case."
if isinstance(test, unittest.TestSuite):
for subtest in test:
run_test(subtest)
else:
test_result = ThreadSafeTestResult(lock, result)
futures.append(executor.submit(_run_one_test, test, test_result))
with executor:
run_test(self)
for future in futures:
future.result()
return result
class JaxTestLoader(absltest.TestLoader):
suiteClass = JaxTestSuite
def getTestCaseNames(self, testCaseClass):
names = super().getTestCaseNames(testCaseClass)
if _TEST_TARGETS.value:
pattern = re.compile(_TEST_TARGETS.value)
names = [name for name in names
if pattern.search(f"{testCaseClass.__name__}.{name}")]
if _EXCLUDE_TEST_TARGETS.value:
pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
names = [name for name in names
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
return names
def with_config(**kwds):
"""Test case decorator for subclasses of JaxTestCase"""