From 3fa557289a39c1207160f76dcfe70fbc3320009f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 8 Jan 2025 08:14:12 -0800 Subject: [PATCH] Port tests away from setUpClass and setUpModule to setUp alone. This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism. If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or setUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally. PiperOrigin-RevId: 713296722 --- jax/_src/test_util.py | 25 +++++++++---------- jax/experimental/jax2tf/tests/call_tf_test.py | 23 +++++++---------- jax/experimental/jax2tf/tests/jax2tf_test.py | 17 +++++-------- .../jax2tf/tests/sharding_test.py | 5 +++- tests/compilation_cache_test.py | 12 +++------ tests/dynamic_api_test.py | 1 + tests/export_harnesses_multi_platform_test.py | 17 ++++++------- tests/export_test.py | 15 ++++++----- tests/mock_gpu_test.py | 2 +- tests/mock_gpu_topology_test.py | 2 +- tests/mosaic/gpu_test.py | 8 +++--- 11 files changed, 56 insertions(+), 71 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 46c442d63..ae898148e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1082,10 +1082,8 @@ class JaxTestCase(parameterized.TestCase): 'jax_legacy_prng_key': 'error', } - _compilation_cache_exit_stack: ExitStack | None = None + _context_stack: ExitStack | None = None - def tearDown(self) -> None: - assert core.reset_trace_state() def setUp(self): super().setUp() @@ -1096,11 +1094,12 @@ class JaxTestCase(parameterized.TestCase): # b) it returns values in int32 range, which RandomState requires. 
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) - @classmethod - def setUpClass(cls): - cls._compilation_cache_exit_stack = ExitStack() - stack = cls._compilation_cache_exit_stack - stack.enter_context(global_config_context(**cls._default_config)) + # TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum + # version. + self._context_stack = ExitStack() + self.addCleanup(self._context_stack.close) + stack = self._context_stack + stack.enter_context(global_config_context(**self._default_config)) if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: stack.enter_context(config.enable_compilation_cache(True)) @@ -1109,12 +1108,12 @@ class JaxTestCase(parameterized.TestCase): stack.enter_context(config.persistent_cache_min_entry_size_bytes(0)) tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) - compilation_cache.set_cache_dir(tmp_dir) - stack.callback(lambda: compilation_cache.reset_cache()) + stack.enter_context(config.compilation_cache_dir(tmp_dir)) + stack.callback(compilation_cache.reset_cache) - @classmethod - def tearDownClass(cls): - cls._compilation_cache_exit_stack.close() + def tearDown(self) -> None: + assert core.reset_trace_state() + super().tearDown() def rng(self): return self._rng diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index f23bd58c4..0cde96aeb 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -69,18 +69,6 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d class CallTfTest(tf_test_util.JaxToTfTestCase): - @classmethod - def setUpClass(cls): - # One TF device of each device_type - cls.tf_devices = [] - for tf_device in tf.config.list_logical_devices(): - if tf_device.device_type == "TPU_SYSTEM": - continue # A virtual device - if all(tf_device.device_type != d.device_type for d in cls.tf_devices): - cls.tf_devices.append(tf_device) - - 
super().setUpClass() - def setUp(self): if tf is None: raise unittest.SkipTest("Test requires tensorflow") @@ -88,6 +76,13 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + # One TF device of each device_type + self.tf_devices = [] + for tf_device in tf.config.list_logical_devices(): + if tf_device.device_type == "TPU_SYSTEM": + continue # A virtual device + if all(tf_device.device_type != d.device_type for d in self.tf_devices): + self.tf_devices.append(tf_device) self.warning_ctx = jtu.ignore_warning( message=( "(jax2tf.convert with native_serialization=False has been deprecated" @@ -798,7 +793,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): jax_and_tf_platforms = ( set(jax_platforms) & {d.device_type.lower() - for d in self.__class__.tf_devices}) + for d in self.tf_devices}) lowering_platforms = ("tpu", "cpu", "cuda") @@ -833,7 +828,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): f_jax, native_serialization=True, native_serialization_platforms=lowering_platforms)) - for tf_device in self.__class__.tf_devices: + for tf_device in self.tf_devices: with self.subTest(tf_device.device_type): logging.info( f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}") diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 7d3313be6..bea2b76cb 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -50,22 +50,17 @@ config.parse_flags_with_absl() class Jax2TfTest(tf_test_util.JaxToTfTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() # One TF device of each device_type - cls.tf_devices = [] + self.tf_devices = [] for tf_device in (tf.config.list_logical_devices("TPU") + tf.config.list_logical_devices("GPU") + tf.config.list_logical_devices()): if tf_device.device_type == "TPU_SYSTEM": continue # A virtual device - if all(tf_device.device_type != 
d.device_type for d in cls.tf_devices): - cls.tf_devices.append(tf_device) - - super().setUpClass() - - def setUp(self): - super().setUp() + if all(tf_device.device_type != d.device_type for d in self.tf_devices): + self.tf_devices.append(tf_device) self.warning_ctx = jtu.ignore_warning( message="jax2tf.convert with native_serialization=False has been deprecated" ) @@ -1666,7 +1661,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): f_jax, native_serialization=True, native_serialization_platforms=("cpu", "cuda", "tpu")) - for tf_device in self.__class__.tf_devices: + for tf_device in self.tf_devices: logging.info( f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}") with tf.device(tf_device): diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 05d2352a4..653ddce7d 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -25,6 +25,7 @@ import re from typing import Any import unittest +from absl import app from absl.testing import absltest import jax @@ -53,7 +54,8 @@ from jax.experimental.jax2tf.tests import tf_test_util topology = None -def setUpModule(): + +def initialize_tf_tpu(): global topology if jtu.test_device_matches(["tpu"]): with jtu.ignore_warning(message="the imp module is deprecated"): @@ -64,6 +66,7 @@ def setUpModule(): else: topology = None +app.call_after_init(initialize_tf_tpu) class ShardingTest(tf_test_util.JaxToTfTestCase): diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 428e518ea..73d76c1a4 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -52,18 +52,12 @@ config.parse_flags_with_absl() FAKE_COMPILE_TIME = 10 _counts = Counter() # Map event name to count - -def setUpModule(): - monitoring.register_event_listener(increment_event_count) - - -def tearDownModule(): - 
monitoring._unregister_event_listener_by_callback(increment_event_count) - - def increment_event_count(event): _counts[event] += 1 +monitoring.register_event_listener(increment_event_count) + + def msg_exists_in_logs(msg: str, records: list[logging.LogRecord], level: int | None = None) -> bool: return any(msg in record.getMessage() for record in records diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index f6625e86c..df79e6aaf 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1487,6 +1487,7 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase): class JumbleTest(jtu.JaxTestCase): def setUp(self): + super().setUp() if jax.config.x64_enabled: raise unittest.SkipTest() @parameterized.parameters((True,), (False,)) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index e8b1afc22..d5878fa50 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -48,11 +48,11 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: class PrimitiveTest(jtu.JaxTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() # Pick one device from each available platform - cls.devices = [] - cls.platforms = [] + self.devices = [] + self.platforms = [] for backend in ["cpu", "gpu", "tpu"]: try: devices = jax.devices(backend) @@ -60,10 +60,9 @@ class PrimitiveTest(jtu.JaxTestCase): devices = [] for d in devices: - if d.platform not in cls.platforms: - cls.platforms.append(d.platform) - cls.devices.append(d) - super().setUpClass() + if d.platform not in self.platforms: + self.platforms.append(d.platform) + self.devices.append(d) # For each primitive we export for all platforms that are available and # compare the results of running the exported code and running the native @@ -128,7 +127,7 @@ class PrimitiveTest(jtu.JaxTestCase): tol: float | None = None): devices = [ d - for d in self.__class__.devices + for d 
in self.devices if d.platform not in unimplemented_platforms ] logging.info("Using devices %s", [str(d) for d in devices]) diff --git a/tests/export_test.py b/tests/export_test.py index 63fe4a8bc..b13cf3a62 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -159,17 +159,16 @@ def get_exported(fun: Callable, vjp_order=0, @jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version) class JaxExportTest(jtu.JaxTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() # Find the available platforms - cls.platforms = [] + self.platforms = [] for backend in ["cpu", "gpu", "tpu"]: try: jax.devices(backend) except RuntimeError: continue - cls.platforms.append(backend) - super().setUpClass() + self.platforms.append(backend) def test_basic_export_only(self): @jax.jit @@ -1499,7 +1498,7 @@ class JaxExportTest(jtu.JaxTestCase): module_str) # Call with argument placed on different plaforms - for platform in self.__class__.platforms: + for platform in self.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) res_exp = exp.call(x_device) self.assertAllClose( @@ -1524,7 +1523,7 @@ class JaxExportTest(jtu.JaxTestCase): self.assertEqual(1, count_sine) # Call with argument placed on different plaforms - for platform in self.__class__.platforms: + for platform in self.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) res_exp = exp2.call(x_device) @@ -1668,7 +1667,7 @@ class JaxExportTest(jtu.JaxTestCase): exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a) # Call with argument placed on different plaforms - for platform in self.__class__.platforms: + for platform in self.platforms: run_devices = jax.devices(platform)[0:len(export_devices)] if len(run_devices) != len(export_devices): continue diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index b84903618..7fb87086d 100644 --- a/tests/mock_gpu_test.py +++ 
b/tests/mock_gpu_test.py @@ -32,9 +32,9 @@ NUM_SHARDS = 4 class MockGPUTest(jtu.JaxTestCase): def setUp(self): + super().setUp() if not jtu.test_device_matches(["gpu"]): self.skipTest("Mocking devices only works on the GPU backend.") - super().setUp() @jtu.skip_under_pytest("Test must run in an isolated process") def testMockDeviceCount(self): diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py index 44ec4e2f9..71ce8f1dd 100644 --- a/tests/mock_gpu_topology_test.py +++ b/tests/mock_gpu_topology_test.py @@ -31,9 +31,9 @@ NUM_HOSTS_PER_SLICE = 4 class MockGPUTopologyTest(jtu.JaxTestCase): def setUp(self): + super().setUp() if not jtu.test_device_matches(["gpu"]): self.skipTest("Mocking devices only works on the GPU backend.") - super().setUp() @jtu.skip_under_pytest("Test must run in an isolated process") def testMockDeviceCount(self): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index dd41b264a..d8ae6d3c2 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -171,7 +171,7 @@ class TestCase(parameterized.TestCase): self.context = mlir.make_ir_context() if mgpu_dialect is not None: mgpu_dialect.register_dialect(self.context) - self.enter_context(jtu.global_config_context(jax_traceback_filtering="off")) + self.enter_context(config.traceback_filtering("off")) self.enter_context(self.context) self.enter_context(ir.Location.unknown()) @@ -1756,13 +1756,13 @@ class ProfilerTest(TestCase): class TorchTest(TestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() try: import torch except ImportError: raise unittest.SkipTest("Test requires PyTorch") - cls.torch = torch + self.torch = torch def test_basic(self): def kernel(ctx, i_gmem, o_gmem, _):