diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 46c442d63..ae898148e 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1082,10 +1082,8 @@ class JaxTestCase(parameterized.TestCase):
       'jax_legacy_prng_key': 'error',
   }
 
-  _compilation_cache_exit_stack: ExitStack | None = None
+  _context_stack: ExitStack | None = None
 
-  def tearDown(self) -> None:
-    assert core.reset_trace_state()
 
   def setUp(self):
     super().setUp()
@@ -1096,11 +1094,12 @@ class JaxTestCase(parameterized.TestCase):
     # b) it returns values in int32 range, which RandomState requires.
     self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
 
-  @classmethod
-  def setUpClass(cls):
-    cls._compilation_cache_exit_stack = ExitStack()
-    stack = cls._compilation_cache_exit_stack
-    stack.enter_context(global_config_context(**cls._default_config))
+    # TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum
+    # version.
+    self._context_stack = ExitStack()
+    self.addCleanup(self._context_stack.close)
+    stack = self._context_stack
+    stack.enter_context(global_config_context(**self._default_config))
 
     if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
       stack.enter_context(config.enable_compilation_cache(True))
@@ -1109,12 +1108,12 @@ class JaxTestCase(parameterized.TestCase):
       stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))
 
       tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
-      compilation_cache.set_cache_dir(tmp_dir)
-      stack.callback(lambda: compilation_cache.reset_cache())
+      stack.enter_context(config.compilation_cache_dir(tmp_dir))
+      stack.callback(compilation_cache.reset_cache)
 
-  @classmethod
-  def tearDownClass(cls):
-    cls._compilation_cache_exit_stack.close()
+  def tearDown(self) -> None:
+    assert core.reset_trace_state()
+    super().tearDown()
 
   def rng(self):
     return self._rng
diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index f23bd58c4..0cde96aeb 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -69,18 +69,6 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d
 
 class CallTfTest(tf_test_util.JaxToTfTestCase):
 
-  @classmethod
-  def setUpClass(cls):
-    # One TF device of each device_type
-    cls.tf_devices = []
-    for tf_device in tf.config.list_logical_devices():
-      if tf_device.device_type == "TPU_SYSTEM":
-        continue  # A virtual device
-      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
-        cls.tf_devices.append(tf_device)
-
-    super().setUpClass()
-
   def setUp(self):
     if tf is None:
       raise unittest.SkipTest("Test requires tensorflow")
@@ -88,6 +76,13 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     # bug in TensorFlow.
     _ = tf.add(1, 1)
     super().setUp()
+    # One TF device of each device_type
+    self.tf_devices = []
+    for tf_device in tf.config.list_logical_devices():
+      if tf_device.device_type == "TPU_SYSTEM":
+        continue  # A virtual device
+      if all(tf_device.device_type != d.device_type for d in self.tf_devices):
+        self.tf_devices.append(tf_device)
     self.warning_ctx = jtu.ignore_warning(
         message=(
             "(jax2tf.convert with native_serialization=False has been deprecated"
@@ -798,7 +793,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     jax_and_tf_platforms = (
       set(jax_platforms) & {d.device_type.lower()
-                            for d in self.__class__.tf_devices})
+                            for d in self.tf_devices})
 
     lowering_platforms = ("tpu", "cpu", "cuda")
 
@@ -833,7 +828,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
         f_jax,
         native_serialization=True,
         native_serialization_platforms=lowering_platforms))
-    for tf_device in self.__class__.tf_devices:
+    for tf_device in self.tf_devices:
       with self.subTest(tf_device.device_type):
         logging.info(
           f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index 7d3313be6..bea2b76cb 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -50,22 +50,17 @@ config.parse_flags_with_absl()
 
 class Jax2TfTest(tf_test_util.JaxToTfTestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # One TF device of each device_type
-    cls.tf_devices = []
+    self.tf_devices = []
     for tf_device in (tf.config.list_logical_devices("TPU") +
                       tf.config.list_logical_devices("GPU") +
                       tf.config.list_logical_devices()):
       if tf_device.device_type == "TPU_SYSTEM":
         continue  # A virtual device
-      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
-        cls.tf_devices.append(tf_device)
-
-    super().setUpClass()
-
-  def setUp(self):
-    super().setUp()
+      if all(tf_device.device_type != d.device_type for d in self.tf_devices):
+        self.tf_devices.append(tf_device)
     self.warning_ctx = jtu.ignore_warning(
         message="jax2tf.convert with native_serialization=False has been deprecated"
     )
@@ -1666,7 +1661,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         f_jax,
         native_serialization=True,
         native_serialization_platforms=("cpu", "cuda", "tpu"))
-    for tf_device in self.__class__.tf_devices:
+    for tf_device in self.tf_devices:
       logging.info(
         f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
       with tf.device(tf_device):
diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py
index 05d2352a4..653ddce7d 100644
--- a/jax/experimental/jax2tf/tests/sharding_test.py
+++ b/jax/experimental/jax2tf/tests/sharding_test.py
@@ -25,6 +25,7 @@ import re
 from typing import Any
 import unittest
 
+from absl import app
 from absl.testing import absltest
 
 import jax
@@ -53,7 +54,8 @@ from jax.experimental.jax2tf.tests import tf_test_util
 
 topology = None
 
-def setUpModule():
+
+def initialize_tf_tpu():
   global topology
   if jtu.test_device_matches(["tpu"]):
     with jtu.ignore_warning(message="the imp module is deprecated"):
@@ -64,6 +66,7 @@
   else:
     topology = None
 
+app.call_after_init(initialize_tf_tpu)
 
 class ShardingTest(tf_test_util.JaxToTfTestCase):
diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py
index 428e518ea..73d76c1a4 100644
--- a/tests/compilation_cache_test.py
+++ b/tests/compilation_cache_test.py
@@ -52,18 +52,12 @@
 config.parse_flags_with_absl()
 
 FAKE_COMPILE_TIME = 10
 _counts = Counter()  # Map event name to count
-
-def setUpModule():
-  monitoring.register_event_listener(increment_event_count)
-
-
-def tearDownModule():
-  monitoring._unregister_event_listener_by_callback(increment_event_count)
-
-
 def increment_event_count(event):
   _counts[event] += 1
 
+monitoring.register_event_listener(increment_event_count)
+
+
 def msg_exists_in_logs(msg: str,
                        records: list[logging.LogRecord],
                        level: int | None = None) -> bool:
   return any(msg in record.getMessage() for record in records
diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py
index f6625e86c..df79e6aaf 100644
--- a/tests/dynamic_api_test.py
+++ b/tests/dynamic_api_test.py
@@ -1487,6 +1487,7 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
 class JumbleTest(jtu.JaxTestCase):
 
   def setUp(self):
+    super().setUp()
     if jax.config.x64_enabled:
       raise unittest.SkipTest()
 
   @parameterized.parameters((True,), (False,))
diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py
index e8b1afc22..d5878fa50 100644
--- a/tests/export_harnesses_multi_platform_test.py
+++ b/tests/export_harnesses_multi_platform_test.py
@@ -48,11 +48,11 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
 
 class PrimitiveTest(jtu.JaxTestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # Pick one device from each available platform
-    cls.devices = []
-    cls.platforms = []
+    self.devices = []
+    self.platforms = []
     for backend in ["cpu", "gpu", "tpu"]:
       try:
         devices = jax.devices(backend)
@@ -60,10 +60,9 @@ class PrimitiveTest(jtu.JaxTestCase):
         devices = []
 
       for d in devices:
-        if d.platform not in cls.platforms:
-          cls.platforms.append(d.platform)
-          cls.devices.append(d)
-    super().setUpClass()
+        if d.platform not in self.platforms:
+          self.platforms.append(d.platform)
+          self.devices.append(d)
 
   # For each primitive we export for all platforms that are available and
   # compare the results of running the exported code and running the native
@@ -128,7 +127,7 @@ class PrimitiveTest(jtu.JaxTestCase):
                             tol: float | None = None):
     devices = [
         d
-        for d in self.__class__.devices
+        for d in self.devices
         if d.platform not in unimplemented_platforms
     ]
     logging.info("Using devices %s", [str(d) for d in devices])
diff --git a/tests/export_test.py b/tests/export_test.py
index 63fe4a8bc..b13cf3a62 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -159,17 +159,16 @@ def get_exported(fun: Callable, vjp_order=0,
 @jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
 class JaxExportTest(jtu.JaxTestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # Find the available platforms
-    cls.platforms = []
+    self.platforms = []
     for backend in ["cpu", "gpu", "tpu"]:
       try:
         jax.devices(backend)
       except RuntimeError:
         continue
-      cls.platforms.append(backend)
-    super().setUpClass()
+      self.platforms.append(backend)
 
   def test_basic_export_only(self):
     @jax.jit
@@ -1499,7 +1498,7 @@ class JaxExportTest(jtu.JaxTestCase):
                   module_str)
 
     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       x_device = jax.device_put(x, jax.devices(platform)[0])
       res_exp = exp.call(x_device)
       self.assertAllClose(
@@ -1524,7 +1523,7 @@ class JaxExportTest(jtu.JaxTestCase):
     self.assertEqual(1, count_sine)
 
     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       if platform == "tpu":
         continue
       x_device = jax.device_put(x, jax.devices(platform)[0])
       res_exp = exp2.call(x_device)
@@ -1668,7 +1667,7 @@ class JaxExportTest(jtu.JaxTestCase):
     exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a)
 
     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       run_devices = jax.devices(platform)[0:len(export_devices)]
       if len(run_devices) != len(export_devices):
         continue
diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py
index b84903618..7fb87086d 100644
--- a/tests/mock_gpu_test.py
+++ b/tests/mock_gpu_test.py
@@ -32,9 +32,9 @@ NUM_SHARDS = 4
 class MockGPUTest(jtu.JaxTestCase):
 
   def setUp(self):
+    super().setUp()
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Mocking devices only works on the GPU backend.")
-    super().setUp()
 
   @jtu.skip_under_pytest("Test must run in an isolated process")
   def testMockDeviceCount(self):
diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py
index 44ec4e2f9..71ce8f1dd 100644
--- a/tests/mock_gpu_topology_test.py
+++ b/tests/mock_gpu_topology_test.py
@@ -31,9 +31,9 @@ NUM_HOSTS_PER_SLICE = 4
 class MockGPUTopologyTest(jtu.JaxTestCase):
 
   def setUp(self):
+    super().setUp()
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Mocking devices only works on the GPU backend.")
-    super().setUp()
 
   @jtu.skip_under_pytest("Test must run in an isolated process")
   def testMockDeviceCount(self):
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index dd41b264a..d8ae6d3c2 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -171,7 +171,7 @@ class TestCase(parameterized.TestCase):
     self.context = mlir.make_ir_context()
     if mgpu_dialect is not None:
       mgpu_dialect.register_dialect(self.context)
-    self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
+    self.enter_context(config.traceback_filtering("off"))
     self.enter_context(self.context)
     self.enter_context(ir.Location.unknown())
 
@@ -1756,13 +1756,13 @@ class ProfilerTest(TestCase):
 
 class TorchTest(TestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     try:
       import torch
     except ImportError:
       raise unittest.SkipTest("Test requires PyTorch")
-    cls.torch = torch
+    self.torch = torch
 
   def test_basic(self):
     def kernel(ctx, i_gmem, o_gmem, _):
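
The common thread of the patch is moving class-level fixtures (`setUpClass`/`tearDownClass`) into per-test `setUp`, with the `jax/_src/test_util.py` hunk using an `ExitStack` closed via `addCleanup` so every context manager is entered and exited per test method. As a reviewer aid, here is a minimal self-contained sketch of that pattern using only the standard library; the test class and the temporary-directory fixture are illustrative stand-ins, not JAX code:

```python
import tempfile
import unittest
from contextlib import ExitStack


class ExampleTest(unittest.TestCase):
  """Illustrates the per-test ExitStack + addCleanup pattern (stand-in example)."""

  def setUp(self):
    super().setUp()
    # TestCase.enterContext exists from Python 3.11; before that, an explicit
    # ExitStack closed via addCleanup gives the same per-test lifetime.
    self._stack = ExitStack()
    self.addCleanup(self._stack.close)
    # Everything entered on the stack is torn down after *this* test method,
    # unlike setUpClass/tearDownClass fixtures, which are shared by the class.
    self.tmp_dir = self._stack.enter_context(tempfile.TemporaryDirectory())

  def test_tmp_dir_is_fresh(self):
    # Each test method sees its own temporary directory.
    self.assertTrue(self.tmp_dir)


if __name__ == "__main__":
  unittest.main()
```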