From 382285d315e8f38447b8b56279d1e38e99496c25 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 9 Apr 2025 18:44:41 -0700 Subject: [PATCH] Split JaxTestLoader and related classes into a separate file. Refactoring only, no functional changes intended. PiperOrigin-RevId: 745813442 --- jax/BUILD | 1 + jax/_src/test_loader.py | 218 ++++++++++++++++++++++++++++++++++++++++ jax/_src/test_util.py | 182 +-------------------------------- 3 files changed, 224 insertions(+), 177 deletions(-) create mode 100644 jax/_src/test_loader.py diff --git a/jax/BUILD b/jax/BUILD index 9891cb556..cb9a39efb 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -142,6 +142,7 @@ py_library( # these are available in jax.test_util via the standard :jax target. name = "test_util", srcs = [ + "_src/test_loader.py", "_src/test_util.py", "_src/test_warning_util.py", ], diff --git a/jax/_src/test_loader.py b/jax/_src/test_loader.py new file mode 100644 index 000000000..d3c44a45e --- /dev/null +++ b/jax/_src/test_loader.py @@ -0,0 +1,218 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Contains a custom unittest loader and test suite. + +Implements: +- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS + environment variables. +- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS + is >= 1. +- Test decorators that mark a test case or test class as thread-hostile. +""" + +from __future__ import annotations + +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +import os +import re +import threading +import time +import unittest + +from absl.testing import absltest +from jax._src import config +from jax._src import test_warning_util +from jax._src import util + + +_TEST_TARGETS = config.string_flag( + 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), + 'Regular expression specifying which tests to run, called via re.search on ' + 'the test name. If empty or unspecified, run all tests.' +) + +_EXCLUDE_TEST_TARGETS = config.string_flag( + 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), + 'Regular expression specifying which tests NOT to run, called via re.search ' + 'on the test name. If empty or unspecified, run all tests.' +) + +TEST_NUM_THREADS = config.int_flag( + 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), + help='Number of threads to use for running tests. 0 means run everything ' + 'in the main thread. Using > 1 thread is experimental.' +) + +# We use a reader-writer lock to protect test execution. Tests that may run in +# parallel acquire a read lock; tests that are not thread-safe acquire a write +# lock. +_test_rwlock = util.Mutex() + +def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): + if getattr(test.__class__, "thread_hostile", False): + _test_rwlock.writer_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.writer_unlock() + else: + _test_rwlock.reader_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.reader_unlock() + + +@contextmanager +def thread_unsafe_test(): + """Decorator for tests that are not thread-safe. + + Note: this decorator (naturally) only applies to what it wraps, not to, say, + code in separate setUp() or tearDown() methods. + """ + if TEST_NUM_THREADS.value <= 0: + yield + return + + _test_rwlock.assert_reader_held() + _test_rwlock.reader_unlock() + _test_rwlock.writer_lock() + try: + yield + finally: + _test_rwlock.writer_unlock() + _test_rwlock.reader_lock() + + +def thread_unsafe_test_class(): + """Decorator that marks a TestCase class as thread-hostile.""" + def f(klass): + assert issubclass(klass, unittest.TestCase), type(klass) + klass.thread_hostile = True + return klass + return f + + +class ThreadSafeTestResult: + """ + Wraps a TestResult to make it thread safe. + + We do this by accumulating API calls and applying them in a batch under a + lock at the conclusion of each test case. + + We duck type instead of inheriting from TestResult because we aren't actually + a perfect implementation of TestResult, and would rather get a loud error + for things we haven't implemented. + """ + def __init__(self, lock: threading.Lock, result: unittest.TestResult): + self.lock = lock + self.test_result = result + self.actions: list[Callable[[], None]] = [] + + def startTest(self, test: unittest.TestCase): + del test + self.start_time = time.time() + + def stopTest(self, test: unittest.TestCase): + stop_time = time.time() + with self.lock: + # If test_result is an ABSL _TextAndXMLTestResult we override how it gets + # the time. This affects the timing that shows up in the XML output + # consumed by CI. + time_getter = getattr(self.test_result, "time_getter", None) + try: + self.test_result.time_getter = lambda: self.start_time + self.test_result.startTest(test) + for callback in self.actions: + callback() + self.test_result.time_getter = lambda: stop_time + self.test_result.stopTest(test) + finally: + if time_getter is not None: + self.test_result.time_getter = time_getter + + def addSuccess(self, test: unittest.TestCase): + self.actions.append(lambda: self.test_result.addSuccess(test)) + + def addSkip(self, test: unittest.TestCase, reason: str): + self.actions.append(lambda: self.test_result.addSkip(test, reason)) + + def addError(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addError(test, err)) + + def addFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addFailure(test, err)) + + def addExpectedFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) + + def addDuration(self, test: unittest.TestCase, elapsed): + self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) + + +class JaxTestSuite(unittest.TestSuite): + """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. + + Caution: this test suite does not run setUpClass or setUpModule methods if + thread parallelism is enabled. + """ + + def __init__(self, suite: unittest.TestSuite): + super().__init__(list(suite)) + + def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: + if TEST_NUM_THREADS.value <= 0: + return super().run(result) + + test_warning_util.install_threadsafe_warning_handlers() + + executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) + lock = threading.Lock() + futures = [] + + def run_test(test): + """Recursively runs tests in a test suite or test case.""" + if isinstance(test, unittest.TestSuite): + for subtest in test: + run_test(subtest) + else: + test_result = ThreadSafeTestResult(lock, result) + futures.append(executor.submit(_run_one_test, test, test_result)) + + with executor: + run_test(self) + for future in futures: + future.result() + + return result + + +class JaxTestLoader(absltest.TestLoader): + suiteClass = JaxTestSuite + + def getTestCaseNames(self, testCaseClass): + names = super().getTestCaseNames(testCaseClass) + if _TEST_TARGETS.value: + pattern = re.compile(_TEST_TARGETS.value) + names = [name for name in names + if pattern.search(f"{testCaseClass.__name__}.{name}")] + if _EXCLUDE_TEST_TARGETS.value: + pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) + names = [name for name in names + if not pattern.search(f"{testCaseClass.__name__}.{name}")] + return names diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 7f08a14bc..c493d8297 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -17,7 +17,6 @@ from __future__ import annotations import collections from collections.abc import Callable, Generator, Iterable, Sequence -from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack, contextmanager import datetime import functools @@ -32,12 +31,10 @@ import sys import tempfile import textwrap import threading -import time from typing import Any, TextIO import unittest import zlib -from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax @@ -63,12 +60,17 @@ from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict) +from jax._src.test_loader import thread_unsafe_test as thread_unsafe_test +from jax._src.test_loader import thread_unsafe_test_class as thread_unsafe_test_class +from jax._src.test_loader import JaxTestLoader as JaxTestLoader +from jax._src.test_loader import TEST_NUM_THREADS as TEST_NUM_THREADS from jax._src.util import unzip2 from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten import numpy as np import numpy.random as npr + # This submodule includes private test utilities that are not exported to # jax.test_util. Functionality appearing here is for internal use only, and # may be changed or removed at any time and without any deprecation cycle. @@ -98,16 +100,6 @@ SKIP_SLOW_TESTS = config.bool_flag( help='Skip tests marked as slow (> 5 sec).' ) -_TEST_TARGETS = config.string_flag( - 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), - 'Regular expression specifying which tests to run, called via re.search on ' - 'the test name. If empty or unspecified, run all tests.' -) -_EXCLUDE_TEST_TARGETS = config.string_flag( - 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), - 'Regular expression specifying which tests NOT to run, called via re.search ' - 'on the test name. If empty or unspecified, run all tests.' -) TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag( 'jax_test_with_persistent_compilation_cache', config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), @@ -121,11 +113,6 @@ HYPOTHESIS_PROFILE = config.string_flag( 'deterministic, interactive'), ) -TEST_NUM_THREADS = config.int_flag( - 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), - help='Number of threads to use for running tests. 0 means run everything ' - 'in the main thread. Using > 1 thread is experimental.' -) # We sanitize test names to ensure they work with "unitttest -k" and # "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k @@ -1074,165 +1061,6 @@ def sample_product(*args, **kw): """ return parameterized.parameters(*sample_product_testcases(*args, **kw)) -# We use a reader-writer lock to protect test execution. Tests that may run in -# parallel acquire a read lock; tests that are not thread-safe acquire a write -# lock. -_test_rwlock = util.Mutex() - -def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): - if getattr(test.__class__, "thread_hostile", False): - _test_rwlock.writer_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.writer_unlock() - else: - _test_rwlock.reader_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.reader_unlock() - - -@contextmanager -def thread_unsafe_test(): - """Decorator for tests that are not thread-safe. - - Note: this decorator (naturally) only applies to what it wraps, not to, say, - code in separate setUp() or tearDown() methods. - """ - if TEST_NUM_THREADS.value <= 0: - yield - return - - _test_rwlock.assert_reader_held() - _test_rwlock.reader_unlock() - _test_rwlock.writer_lock() - try: - yield - finally: - _test_rwlock.writer_unlock() - _test_rwlock.reader_lock() - - -def thread_unsafe_test_class(): - "Decorator that marks a TestCase class as thread-hostile." - def f(klass): - assert issubclass(klass, unittest.TestCase), type(klass) - klass.thread_hostile = True - return klass - return f - - -class ThreadSafeTestResult: - """ - Wraps a TestResult to make it thread safe. - - We do this by accumulating API calls and applying them in a batch under a - lock at the conclusion of each test case. - - We duck type instead of inheriting from TestResult because we aren't actually - a perfect implementation of TestResult, and would rather get a loud error - for things we haven't implemented. - """ - def __init__(self, lock: threading.Lock, result: unittest.TestResult): - self.lock = lock - self.test_result = result - self.actions: list[Callable] = [] - - def startTest(self, test: unittest.TestCase): - del test - self.start_time = time.time() - - def stopTest(self, test: unittest.TestCase): - stop_time = time.time() - with self.lock: - # If test_result is an ABSL _TextAndXMLTestResult we override how it gets - # the time. This affects the timing that shows up in the XML output - # consumed by CI. - time_getter = getattr(self.test_result, "time_getter", None) - try: - self.test_result.time_getter = lambda: self.start_time - self.test_result.startTest(test) - for callback in self.actions: - callback() - self.test_result.time_getter = lambda: stop_time - self.test_result.stopTest(test) - finally: - if time_getter is not None: - self.test_result.time_getter = time_getter - - def addSuccess(self, test: unittest.TestCase): - self.actions.append(lambda: self.test_result.addSuccess(test)) - - def addSkip(self, test: unittest.TestCase, reason: str): - self.actions.append(lambda: self.test_result.addSkip(test, reason)) - - def addError(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addError(test, err)) - - def addFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addFailure(test, err)) - - def addExpectedFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) - - def addDuration(self, test: unittest.TestCase, elapsed): - self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) - - -class JaxTestSuite(unittest.TestSuite): - """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. - - Caution: this test suite does not run setUpClass or setUpModule methods if - thread parallelism is enabled. - """ - - def __init__(self, suite: unittest.TestSuite): - super().__init__(list(suite)) - - def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: - if TEST_NUM_THREADS.value <= 0: - return super().run(result) - - test_warning_util.install_threadsafe_warning_handlers() - - executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) - lock = threading.Lock() - futures = [] - - def run_test(test): - "Recursively runs tests in a test suite or test case." - if isinstance(test, unittest.TestSuite): - for subtest in test: - run_test(subtest) - else: - test_result = ThreadSafeTestResult(lock, result) - futures.append(executor.submit(_run_one_test, test, test_result)) - - with executor: - run_test(self) - for future in futures: - future.result() - - return result - - -class JaxTestLoader(absltest.TestLoader): - suiteClass = JaxTestSuite - - def getTestCaseNames(self, testCaseClass): - names = super().getTestCaseNames(testCaseClass) - if _TEST_TARGETS.value: - pattern = re.compile(_TEST_TARGETS.value) - names = [name for name in names - if pattern.search(f"{testCaseClass.__name__}.{name}")] - if _EXCLUDE_TEST_TARGETS.value: - pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) - names = [name for name in names - if not pattern.search(f"{testCaseClass.__name__}.{name}")] - return names - def with_config(**kwds): """Test case decorator for subclasses of JaxTestCase"""