Port tests away from setUpClass and setUpModule to setUp alone.

This change prepares for upcoming changes in which we run tests in parallel using threads, partly to test free threading and partly to speed up TPU tests via thread-level parallelism.

If independent tests run in parallel in no particular order, there is no natural scope around which to call setUpClass or setUpModule. But for JAX tests this never seems necessary: we can do the same work in setUp() or do it globally at module scope.
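
To illustrate the port concretely (a minimal sketch, not code from this change; discover_devices is a hypothetical stand-in for the per-class work):

import unittest

def discover_devices():
  # Hypothetical stand-in for work previously done once per class.
  return ["cpu"]

# Before: state is created once per class, which assumes the class's tests
# run together and in sequence.
class BeforeTest(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    cls.devices = discover_devices()

# After: state is created per test, which is safe under any execution order.
class AfterTest(unittest.TestCase):
  def setUp(self):
    super().setUp()
    self.devices = discover_devices()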

PiperOrigin-RevId: 713296722
Peter Hawkins 2025-01-08 08:14:12 -08:00 committed by jax authors
parent f1f98afee8
commit 3fa557289a
11 changed files with 56 additions and 71 deletions

View File

@@ -1082,10 +1082,8 @@ class JaxTestCase(parameterized.TestCase):
       'jax_legacy_prng_key': 'error',
   }

-  _compilation_cache_exit_stack: ExitStack | None = None
+  _context_stack: ExitStack | None = None

-  def tearDown(self) -> None:
-    assert core.reset_trace_state()

   def setUp(self):
     super().setUp()
@@ -1096,11 +1094,12 @@ class JaxTestCase(parameterized.TestCase):
     # b) it returns values in int32 range, which RandomState requires.
     self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))

-  @classmethod
-  def setUpClass(cls):
-    cls._compilation_cache_exit_stack = ExitStack()
-    stack = cls._compilation_cache_exit_stack
-    stack.enter_context(global_config_context(**cls._default_config))
+    # TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum
+    # version.
+    self._context_stack = ExitStack()
+    self.addCleanup(self._context_stack.close)
+    stack = self._context_stack
+    stack.enter_context(global_config_context(**self._default_config))

     if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
       stack.enter_context(config.enable_compilation_cache(True))
@@ -1109,12 +1108,12 @@ class JaxTestCase(parameterized.TestCase):
       stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))

       tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
-      compilation_cache.set_cache_dir(tmp_dir)
-      stack.callback(lambda: compilation_cache.reset_cache())
+      stack.enter_context(config.compilation_cache_dir(tmp_dir))
+      stack.callback(compilation_cache.reset_cache)

-  @classmethod
-  def tearDownClass(cls):
-    cls._compilation_cache_exit_stack.close()
+  def tearDown(self) -> None:
+    assert core.reset_trace_state()
+    super().tearDown()

   def rng(self):
     return self._rng
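
Aside: the ExitStack-plus-addCleanup idiom introduced above is the pre-Python-3.11 spelling of unittest.TestCase.enterContext, which the TODO mentions. A minimal sketch of the equivalence, not taken from the diff:

import tempfile
import unittest
from contextlib import ExitStack

class ContextStackExample(unittest.TestCase):
  def setUp(self):
    super().setUp()
    # Pre-3.11: an ExitStack whose close() is registered as a cleanup, so
    # entered contexts are exited even if setUp later raises or skips.
    self._stack = ExitStack()
    self.addCleanup(self._stack.close)
    self.tmp_dir = self._stack.enter_context(tempfile.TemporaryDirectory())
    # Python 3.11+ one-liner with the same effect:
    #   self.tmp_dir = self.enterContext(tempfile.TemporaryDirectory())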

View File

@@ -69,18 +69,6 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d

 class CallTfTest(tf_test_util.JaxToTfTestCase):

-  @classmethod
-  def setUpClass(cls):
-    # One TF device of each device_type
-    cls.tf_devices = []
-    for tf_device in tf.config.list_logical_devices():
-      if tf_device.device_type == "TPU_SYSTEM":
-        continue  # A virtual device
-      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
-        cls.tf_devices.append(tf_device)
-    super().setUpClass()
-
   def setUp(self):
     if tf is None:
       raise unittest.SkipTest("Test requires tensorflow")
@@ -88,6 +76,13 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     # bug in TensorFlow.
     _ = tf.add(1, 1)
     super().setUp()
+    # One TF device of each device_type
+    self.tf_devices = []
+    for tf_device in tf.config.list_logical_devices():
+      if tf_device.device_type == "TPU_SYSTEM":
+        continue  # A virtual device
+      if all(tf_device.device_type != d.device_type for d in self.tf_devices):
+        self.tf_devices.append(tf_device)
     self.warning_ctx = jtu.ignore_warning(
         message=(
             "(jax2tf.convert with native_serialization=False has been deprecated"
@@ -798,7 +793,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     jax_and_tf_platforms = (
       set(jax_platforms) & {d.device_type.lower()
-                            for d in self.__class__.tf_devices})
+                            for d in self.tf_devices})

     lowering_platforms = ("tpu", "cpu", "cuda")
@@ -833,7 +828,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
         f_jax,
         native_serialization=True,
         native_serialization_platforms=lowering_platforms))
-    for tf_device in self.__class__.tf_devices:
+    for tf_device in self.tf_devices:
       with self.subTest(tf_device.device_type):
         logging.info(
             f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")

View File

@@ -50,22 +50,17 @@ config.parse_flags_with_absl()

 class Jax2TfTest(tf_test_util.JaxToTfTestCase):

-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # One TF device of each device_type
-    cls.tf_devices = []
+    self.tf_devices = []
     for tf_device in (tf.config.list_logical_devices("TPU") +
                       tf.config.list_logical_devices("GPU") +
                       tf.config.list_logical_devices()):
       if tf_device.device_type == "TPU_SYSTEM":
         continue  # A virtual device
-      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
-        cls.tf_devices.append(tf_device)
-    super().setUpClass()
-
-  def setUp(self):
-    super().setUp()
+      if all(tf_device.device_type != d.device_type for d in self.tf_devices):
+        self.tf_devices.append(tf_device)
     self.warning_ctx = jtu.ignore_warning(
         message="jax2tf.convert with native_serialization=False has been deprecated"
     )
@@ -1666,7 +1661,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         f_jax,
         native_serialization=True,
         native_serialization_platforms=("cpu", "cuda", "tpu"))
-    for tf_device in self.__class__.tf_devices:
+    for tf_device in self.tf_devices:
       logging.info(
           f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
       with tf.device(tf_device):

View File

@@ -25,6 +25,7 @@ import re
 from typing import Any
 import unittest

+from absl import app
 from absl.testing import absltest

 import jax
@@ -53,7 +54,8 @@ from jax.experimental.jax2tf.tests import tf_test_util
 topology = None

-def setUpModule():
+def initialize_tf_tpu():
   global topology
   if jtu.test_device_matches(["tpu"]):
     with jtu.ignore_warning(message="the imp module is deprecated"):
@@ -64,6 +66,7 @@ def setUpModule():
   else:
     topology = None

+app.call_after_init(initialize_tf_tpu)

 class ShardingTest(tf_test_util.JaxToTfTestCase):
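
Aside: absl.app.call_after_init runs its callback once absl has parsed command-line flags, so it is a natural home for flag-dependent module initialization that previously lived in setUpModule. A minimal sketch, with a hypothetical initialization body:

from absl import app
from absl.testing import absltest

_topology = None

def _initialize():
  # One-time, flag-dependent setup; runs after flag parsing, independent of
  # which test happens to execute first or on which thread.
  global _topology
  _topology = object()  # hypothetical expensive initialization

app.call_after_init(_initialize)

if __name__ == "__main__":
  absltest.main()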

View File

@@ -52,18 +52,12 @@ config.parse_flags_with_absl()
 FAKE_COMPILE_TIME = 10

 _counts = Counter()  # Map event name to count

-def setUpModule():
-  monitoring.register_event_listener(increment_event_count)
-
-def tearDownModule():
-  monitoring._unregister_event_listener_by_callback(increment_event_count)
-
 def increment_event_count(event):
   _counts[event] += 1

+monitoring.register_event_listener(increment_event_count)

 def msg_exists_in_logs(msg: str, records: list[logging.LogRecord],
                        level: int | None = None) -> bool:
   return any(msg in record.getMessage() for record in records

View File

@@ -1487,6 +1487,7 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
 class JumbleTest(jtu.JaxTestCase):

   def setUp(self):
+    super().setUp()
     if jax.config.x64_enabled: raise unittest.SkipTest()

   @parameterized.parameters((True,), (False,))

View File

@@ -48,11 +48,11 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:

 class PrimitiveTest(jtu.JaxTestCase):

-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # Pick one device from each available platform
-    cls.devices = []
-    cls.platforms = []
+    self.devices = []
+    self.platforms = []
     for backend in ["cpu", "gpu", "tpu"]:
       try:
         devices = jax.devices(backend)
@@ -60,10 +60,9 @@ class PrimitiveTest(jtu.JaxTestCase):
         devices = []
     for d in devices:
-      if d.platform not in cls.platforms:
-        cls.platforms.append(d.platform)
-        cls.devices.append(d)
-    super().setUpClass()
+      if d.platform not in self.platforms:
+        self.platforms.append(d.platform)
+        self.devices.append(d)

   # For each primitive we export for all platforms that are available and
   # compare the results of running the exported code and running the native
@@ -128,7 +127,7 @@ class PrimitiveTest(jtu.JaxTestCase):
                          tol: float | None = None):
     devices = [
         d
-        for d in self.__class__.devices
+        for d in self.devices
         if d.platform not in unimplemented_platforms
     ]
     logging.info("Using devices %s", [str(d) for d in devices])

View File

@@ -159,17 +159,16 @@ def get_exported(fun: Callable, vjp_order=0,
 @jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
 class JaxExportTest(jtu.JaxTestCase):

-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # Find the available platforms
-    cls.platforms = []
+    self.platforms = []
     for backend in ["cpu", "gpu", "tpu"]:
       try:
         jax.devices(backend)
       except RuntimeError:
         continue
-      cls.platforms.append(backend)
-    super().setUpClass()
+      self.platforms.append(backend)

   def test_basic_export_only(self):
     @jax.jit
@@ -1499,7 +1498,7 @@
                      module_str)

     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       x_device = jax.device_put(x, jax.devices(platform)[0])
       res_exp = exp.call(x_device)
       self.assertAllClose(
@@ -1524,7 +1523,7 @@
     self.assertEqual(1, count_sine)

     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       if platform == "tpu": continue
       x_device = jax.device_put(x, jax.devices(platform)[0])
       res_exp = exp2.call(x_device)
@@ -1668,7 +1667,7 @@
     exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a)

     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       run_devices = jax.devices(platform)[0:len(export_devices)]
       if len(run_devices) != len(export_devices):
         continue

View File

@@ -32,9 +32,9 @@ NUM_SHARDS = 4
 class MockGPUTest(jtu.JaxTestCase):

   def setUp(self):
+    super().setUp()
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Mocking devices only works on the GPU backend.")
-    super().setUp()

   @jtu.skip_under_pytest("Test must run in an isolated process")
   def testMockDeviceCount(self):
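
Aside: this hunk (and the identical one in the next file) moves super().setUp() ahead of the skip check, the conventional ordering for cooperative setUp chains: the base class's per-test state is established before any test-specific logic, including skips, runs. A minimal sketch, with a hypothetical availability check:

import unittest

def gpu_available():
  return False  # hypothetical backend check

class OrderingExample(unittest.TestCase):
  def setUp(self):
    super().setUp()  # establish base-class state first...
    if not gpu_available():
      self.skipTest("requires GPU")  # ...then decide whether to skip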

View File

@@ -31,9 +31,9 @@ NUM_HOSTS_PER_SLICE = 4
 class MockGPUTopologyTest(jtu.JaxTestCase):

   def setUp(self):
+    super().setUp()
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Mocking devices only works on the GPU backend.")
-    super().setUp()

   @jtu.skip_under_pytest("Test must run in an isolated process")
   def testMockDeviceCount(self):

View File

@@ -171,7 +171,7 @@ class TestCase(parameterized.TestCase):
     self.context = mlir.make_ir_context()
     if mgpu_dialect is not None:
       mgpu_dialect.register_dialect(self.context)
-    self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
+    self.enter_context(config.traceback_filtering("off"))
     self.enter_context(self.context)
     self.enter_context(ir.Location.unknown())
@@ -1756,13 +1756,13 @@ class ProfilerTest(TestCase):
 class TorchTest(TestCase):

-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     try:
       import torch
     except ImportError:
       raise unittest.SkipTest("Test requires PyTorch")
-    cls.torch = torch
+    self.torch = torch

   def test_basic(self):
     def kernel(ctx, i_gmem, o_gmem, _):