mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism. If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally. PiperOrigin-RevId: 713296722
This commit is contained in:
parent
f1f98afee8
commit
3fa557289a
@ -1082,10 +1082,8 @@ class JaxTestCase(parameterized.TestCase):
|
||||
'jax_legacy_prng_key': 'error',
|
||||
}
|
||||
|
||||
_compilation_cache_exit_stack: ExitStack | None = None
|
||||
_context_stack: ExitStack | None = None
|
||||
|
||||
def tearDown(self) -> None:
|
||||
assert core.reset_trace_state()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -1096,11 +1094,12 @@ class JaxTestCase(parameterized.TestCase):
|
||||
# b) it returns values in int32 range, which RandomState requires.
|
||||
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._compilation_cache_exit_stack = ExitStack()
|
||||
stack = cls._compilation_cache_exit_stack
|
||||
stack.enter_context(global_config_context(**cls._default_config))
|
||||
# TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum
|
||||
# version.
|
||||
self._context_stack = ExitStack()
|
||||
self.addCleanup(self._context_stack.close)
|
||||
stack = self._context_stack
|
||||
stack.enter_context(global_config_context(**self._default_config))
|
||||
|
||||
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||
stack.enter_context(config.enable_compilation_cache(True))
|
||||
@ -1109,12 +1108,12 @@ class JaxTestCase(parameterized.TestCase):
|
||||
stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))
|
||||
|
||||
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
|
||||
compilation_cache.set_cache_dir(tmp_dir)
|
||||
stack.callback(lambda: compilation_cache.reset_cache())
|
||||
stack.enter_context(config.compilation_cache_dir(tmp_dir))
|
||||
stack.callback(compilation_cache.reset_cache)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._compilation_cache_exit_stack.close()
|
||||
def tearDown(self) -> None:
|
||||
assert core.reset_trace_state()
|
||||
super().tearDown()
|
||||
|
||||
def rng(self):
|
||||
return self._rng
|
||||
|
@ -69,18 +69,6 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d
|
||||
|
||||
class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# One TF device of each device_type
|
||||
cls.tf_devices = []
|
||||
for tf_device in tf.config.list_logical_devices():
|
||||
if tf_device.device_type == "TPU_SYSTEM":
|
||||
continue # A virtual device
|
||||
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
|
||||
cls.tf_devices.append(tf_device)
|
||||
|
||||
super().setUpClass()
|
||||
|
||||
def setUp(self):
|
||||
if tf is None:
|
||||
raise unittest.SkipTest("Test requires tensorflow")
|
||||
@ -88,6 +76,13 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
# bug in TensorFlow.
|
||||
_ = tf.add(1, 1)
|
||||
super().setUp()
|
||||
# One TF device of each device_type
|
||||
self.tf_devices = []
|
||||
for tf_device in tf.config.list_logical_devices():
|
||||
if tf_device.device_type == "TPU_SYSTEM":
|
||||
continue # A virtual device
|
||||
if all(tf_device.device_type != d.device_type for d in self.tf_devices):
|
||||
self.tf_devices.append(tf_device)
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message=(
|
||||
"(jax2tf.convert with native_serialization=False has been deprecated"
|
||||
@ -798,7 +793,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
jax_and_tf_platforms = (
|
||||
set(jax_platforms) & {d.device_type.lower()
|
||||
for d in self.__class__.tf_devices})
|
||||
for d in self.tf_devices})
|
||||
|
||||
lowering_platforms = ("tpu", "cpu", "cuda")
|
||||
|
||||
@ -833,7 +828,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_jax,
|
||||
native_serialization=True,
|
||||
native_serialization_platforms=lowering_platforms))
|
||||
for tf_device in self.__class__.tf_devices:
|
||||
for tf_device in self.tf_devices:
|
||||
with self.subTest(tf_device.device_type):
|
||||
logging.info(
|
||||
f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
|
||||
|
@ -50,22 +50,17 @@ config.parse_flags_with_absl()
|
||||
|
||||
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# One TF device of each device_type
|
||||
cls.tf_devices = []
|
||||
self.tf_devices = []
|
||||
for tf_device in (tf.config.list_logical_devices("TPU") +
|
||||
tf.config.list_logical_devices("GPU") +
|
||||
tf.config.list_logical_devices()):
|
||||
if tf_device.device_type == "TPU_SYSTEM":
|
||||
continue # A virtual device
|
||||
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
|
||||
cls.tf_devices.append(tf_device)
|
||||
|
||||
super().setUpClass()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if all(tf_device.device_type != d.device_type for d in self.tf_devices):
|
||||
self.tf_devices.append(tf_device)
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||
)
|
||||
@ -1666,7 +1661,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_jax,
|
||||
native_serialization=True,
|
||||
native_serialization_platforms=("cpu", "cuda", "tpu"))
|
||||
for tf_device in self.__class__.tf_devices:
|
||||
for tf_device in self.tf_devices:
|
||||
logging.info(
|
||||
f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
|
||||
with tf.device(tf_device):
|
||||
|
@ -25,6 +25,7 @@ import re
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
from absl import app
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
@ -53,7 +54,8 @@ from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
||||
topology = None
|
||||
|
||||
def setUpModule():
|
||||
|
||||
def initialize_tf_tpu():
|
||||
global topology
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
with jtu.ignore_warning(message="the imp module is deprecated"):
|
||||
@ -64,6 +66,7 @@ def setUpModule():
|
||||
else:
|
||||
topology = None
|
||||
|
||||
app.call_after_init(initialize_tf_tpu)
|
||||
|
||||
|
||||
class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
|
@ -52,18 +52,12 @@ config.parse_flags_with_absl()
|
||||
FAKE_COMPILE_TIME = 10
|
||||
_counts = Counter() # Map event name to count
|
||||
|
||||
|
||||
def setUpModule():
|
||||
monitoring.register_event_listener(increment_event_count)
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
monitoring._unregister_event_listener_by_callback(increment_event_count)
|
||||
|
||||
|
||||
def increment_event_count(event):
|
||||
_counts[event] += 1
|
||||
|
||||
monitoring.register_event_listener(increment_event_count)
|
||||
|
||||
|
||||
def msg_exists_in_logs(msg: str, records: list[logging.LogRecord],
|
||||
level: int | None = None) -> bool:
|
||||
return any(msg in record.getMessage() for record in records
|
||||
|
@ -1487,6 +1487,7 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
|
||||
class JumbleTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jax.config.x64_enabled: raise unittest.SkipTest()
|
||||
|
||||
@parameterized.parameters((True,), (False,))
|
||||
|
@ -48,11 +48,11 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
|
||||
|
||||
class PrimitiveTest(jtu.JaxTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Pick one device from each available platform
|
||||
cls.devices = []
|
||||
cls.platforms = []
|
||||
self.devices = []
|
||||
self.platforms = []
|
||||
for backend in ["cpu", "gpu", "tpu"]:
|
||||
try:
|
||||
devices = jax.devices(backend)
|
||||
@ -60,10 +60,9 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
devices = []
|
||||
|
||||
for d in devices:
|
||||
if d.platform not in cls.platforms:
|
||||
cls.platforms.append(d.platform)
|
||||
cls.devices.append(d)
|
||||
super().setUpClass()
|
||||
if d.platform not in self.platforms:
|
||||
self.platforms.append(d.platform)
|
||||
self.devices.append(d)
|
||||
|
||||
# For each primitive we export for all platforms that are available and
|
||||
# compare the results of running the exported code and running the native
|
||||
@ -128,7 +127,7 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
tol: float | None = None):
|
||||
devices = [
|
||||
d
|
||||
for d in self.__class__.devices
|
||||
for d in self.devices
|
||||
if d.platform not in unimplemented_platforms
|
||||
]
|
||||
logging.info("Using devices %s", [str(d) for d in devices])
|
||||
|
@ -159,17 +159,16 @@ def get_exported(fun: Callable, vjp_order=0,
|
||||
@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
|
||||
class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Find the available platforms
|
||||
cls.platforms = []
|
||||
self.platforms = []
|
||||
for backend in ["cpu", "gpu", "tpu"]:
|
||||
try:
|
||||
jax.devices(backend)
|
||||
except RuntimeError:
|
||||
continue
|
||||
cls.platforms.append(backend)
|
||||
super().setUpClass()
|
||||
self.platforms.append(backend)
|
||||
|
||||
def test_basic_export_only(self):
|
||||
@jax.jit
|
||||
@ -1499,7 +1498,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
module_str)
|
||||
|
||||
# Call with argument placed on different plaforms
|
||||
for platform in self.__class__.platforms:
|
||||
for platform in self.platforms:
|
||||
x_device = jax.device_put(x, jax.devices(platform)[0])
|
||||
res_exp = exp.call(x_device)
|
||||
self.assertAllClose(
|
||||
@ -1524,7 +1523,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertEqual(1, count_sine)
|
||||
|
||||
# Call with argument placed on different plaforms
|
||||
for platform in self.__class__.platforms:
|
||||
for platform in self.platforms:
|
||||
if platform == "tpu": continue
|
||||
x_device = jax.device_put(x, jax.devices(platform)[0])
|
||||
res_exp = exp2.call(x_device)
|
||||
@ -1668,7 +1667,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a)
|
||||
|
||||
# Call with argument placed on different plaforms
|
||||
for platform in self.__class__.platforms:
|
||||
for platform in self.platforms:
|
||||
run_devices = jax.devices(platform)[0:len(export_devices)]
|
||||
if len(run_devices) != len(export_devices):
|
||||
continue
|
||||
|
@ -32,9 +32,9 @@ NUM_SHARDS = 4
|
||||
class MockGPUTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Mocking devices only works on the GPU backend.")
|
||||
super().setUp()
|
||||
|
||||
@jtu.skip_under_pytest("Test must run in an isolated process")
|
||||
def testMockDeviceCount(self):
|
||||
|
@ -31,9 +31,9 @@ NUM_HOSTS_PER_SLICE = 4
|
||||
class MockGPUTopologyTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Mocking devices only works on the GPU backend.")
|
||||
super().setUp()
|
||||
|
||||
@jtu.skip_under_pytest("Test must run in an isolated process")
|
||||
def testMockDeviceCount(self):
|
||||
|
@ -171,7 +171,7 @@ class TestCase(parameterized.TestCase):
|
||||
self.context = mlir.make_ir_context()
|
||||
if mgpu_dialect is not None:
|
||||
mgpu_dialect.register_dialect(self.context)
|
||||
self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
|
||||
self.enter_context(config.traceback_filtering("off"))
|
||||
self.enter_context(self.context)
|
||||
self.enter_context(ir.Location.unknown())
|
||||
|
||||
@ -1756,13 +1756,13 @@ class ProfilerTest(TestCase):
|
||||
|
||||
class TorchTest(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise unittest.SkipTest("Test requires PyTorch")
|
||||
cls.torch = torch
|
||||
self.torch = torch
|
||||
|
||||
def test_basic(self):
|
||||
def kernel(ctx, i_gmem, o_gmem, _):
|
||||
|
Loading…
x
Reference in New Issue
Block a user