diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index a2f608877..46c442d63 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -498,29 +498,20 @@ def device_supports_buffer_donation(): ) -@contextmanager -def set_host_platform_device_count(nr_devices: int): - """Context manager to set host platform device count if not specified by user. +def request_cpu_devices(nr_devices: int): + """Requests at least `nr_devices` CPU devices. - This should only be used by tests at the top level in setUpModule(); it will - not work correctly if applied to individual test cases. + request_cpu_devices should be called at the top-level of a test module before + main() runs. + + It is not guaranteed that the number of CPU devices will be exactly + `nr_devices`: it may be more or less, depending on how exactly the test is + invoked. Test cases that require a specific number of devices should skip + themselves if that number is not met. """ - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - f" --xla_force_host_platform_device_count={nr_devices}") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - try: - yield - finally: - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags + if xla_bridge.NUM_CPU_DEVICES.value < nr_devices: xla_bridge.get_backend.cache_clear() + config.update("jax_num_cpu_devices", nr_devices) def skip_on_flag(flag_name, skip_value): diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 28148761c..bbe663175 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -122,6 +122,14 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( "inline without async dispatch.", ) +NUM_CPU_DEVICES = config.int_flag( + name="jax_num_cpu_devices", + default=-1, + help="Number of CPU devices to use. If not provided, the value of " + "the XLA flag --xla_force_host_platform_device_count is used." + " Must be set before JAX is initialized.", +) + # Warn the user if they call fork(), because it's not going to go well for them. def _at_fork(): @@ -249,8 +257,8 @@ def make_cpu_client( if collectives is None: collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value if _CPU_ENABLE_GLOO_COLLECTIVES.value: - collectives_impl = 'gloo' - warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' + collectives_impl = 'gloo' + warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' 'deprecated. Please use `jax.config.update(' '"jax_cpu_collectives_implementation", "gloo")` instead.', DeprecationWarning, @@ -268,12 +276,22 @@ def make_cpu_client( f"{collectives_impl}. Available implementations are " f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.") + num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None + if xla_client._version < 303 and num_devices is not None: + xla_flags = os.getenv("XLA_FLAGS") or "" + os.environ["XLA_FLAGS"] = ( + f"{xla_flags} --xla_force_host_platform_device_count={num_devices}" + ) + num_devices = None + # TODO(phawkins): pass num_devices directly when version 303 is the minimum. + kwargs = {} if num_devices is None else {"num_devices": num_devices} return xla_client.make_cpu_client( asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, distributed_client=distributed.global_state.client, node_id=distributed.global_state.process_id, num_nodes=distributed.global_state.num_processes, collectives=collectives, + **kwargs, ) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index a4bb168ef..6ec621d68 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -14,7 +14,6 @@ """Tests for serialization and deserialization of GDA.""" import asyncio -import contextlib import math from functools import partial import os @@ -36,13 +35,7 @@ import numpy as np import tensorstore as ts jax.config.parse_flags_with_absl() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class CheckpointTest(jtu.JaxTestCase): diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 8fe9a1dd9..05d2352a4 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -19,7 +19,6 @@ """ from collections.abc import Sequence -import contextlib from functools import partial import logging import re @@ -47,16 +46,14 @@ import numpy as np import tensorflow as tf config.parse_flags_with_absl() +jtu.request_cpu_devices(8) # Must come after initializing the flags from jax.experimental.jax2tf.tests import tf_test_util -_exit_stack = contextlib.ExitStack() topology = None def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - global topology if jtu.test_device_matches(["tpu"]): with jtu.ignore_warning(message="the imp module is deprecated"): @@ -67,8 +64,6 @@ def setUpModule(): else: topology = None -def tearDownModule(): - _exit_stack.close() class ShardingTest(tf_test_util.JaxToTfTestCase): diff --git a/tests/array_test.py b/tests/array_test.py index 9618a8cf4..97bf71a52 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -43,20 +43,12 @@ from jax._src import array from jax._src import prng jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(8) with contextlib.suppress(ImportError): import pytest pytestmark = pytest.mark.multiaccelerator -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() - def create_array(shape, sharding, global_data=None): if global_data is None: diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 4485f5d4f..f9dd3ce52 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import threading import time from typing import Sequence @@ -29,6 +28,7 @@ import jax.numpy as jnp import numpy as np config.parse_flags_with_absl() +jtu.request_cpu_devices(8) def _colocated_cpu_devices( @@ -53,18 +53,6 @@ def _colocated_cpu_devices( _count_colocated_python_specialization_cache_miss = jtu.count_events( "colocated_python_func._get_specialized_func") -_exit_stack = contextlib.ExitStack() - - -def setUpModule(): - # TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT - # prepares CPU devices by its own. - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - - -def tearDownModule(): - _exit_stack.close() - class ColocatedPythonTest(jtu.JaxTestCase): diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 18693a7bb..419e7b18d 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Sequence -import contextlib import io import re import textwrap @@ -29,6 +28,7 @@ import jax.numpy as jnp import numpy as np jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(2) def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]: fake_stdin = io.StringIO() @@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI def _format_multiline(text): return textwrap.dedent(text).lstrip() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - foo = 2 class CliDebuggerTest(jtu.JaxTestCase): diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 6afb41645..0fc9665ce 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections -import contextlib import functools import textwrap import unittest @@ -35,19 +34,13 @@ except ModuleNotFoundError: rich = None jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(2) debug_print = debugging.debug_print def _format_multiline(text): return textwrap.dedent(text).lstrip() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() class DummyDevice: def __init__(self, platform, id): diff --git a/tests/export_test.py b/tests/export_test.py index da0e9daf2..63fe4a8bc 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -56,14 +56,8 @@ except (ModuleNotFoundError, ImportError): CAN_SERIALIZE = False config.parse_flags_with_absl() +jtu.request_cpu_devices(8) -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() ### Setup for testing lowering with effects @dataclasses.dataclass(frozen=True) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 2e91792aa..922b37ffa 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib + import threading import unittest @@ -34,6 +34,7 @@ from jax._src.interpreters import partial_eval as pe import numpy as np config.parse_flags_with_absl() +jtu.request_cpu_devices(2) effect_p = core.Primitive('effect') effect_p.multiple_results = True @@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out mlir.register_lowering(callback_p, callback_effect_lowering) -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - - class JaxprEffectsTest(jtu.JaxTestCase): def test_trivial_jaxpr_has_no_effects(self): diff --git a/tests/layout_test.py b/tests/layout_test.py index f958de5cf..903b17886 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import math from functools import partial from absl.testing import absltest @@ -28,14 +27,7 @@ from jax._src.util import safe_zip from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class LayoutTest(jtu.JaxTestCase): diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 057731cb5..1fc6fe1e9 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib from unittest import SkipTest import tracemalloc as tm @@ -25,15 +24,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax._src import test_util as jtu jax.config.parse_flags_with_absl() - -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class MultiDeviceTest(jtu.JaxTestCase): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 77bcee296..3fcc5c81a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import OrderedDict, namedtuple -import contextlib import re from functools import partial import logging @@ -64,14 +63,7 @@ from jax._src.util import curry, unzip2 config.parse_flags_with_absl() -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) def create_array(global_shape, global_mesh, mesh_axes, global_data=None, dtype=np.float32): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index a9de8c896..795f7d4bf 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -15,7 +15,6 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -import contextlib from functools import partial import itertools as it import gc @@ -54,15 +53,8 @@ from jax._src.lib import xla_extension from jax._src.util import safe_map, safe_zip config.parse_flags_with_absl() +jtu.request_cpu_devices(8) -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 199b90fe5..efa877fd3 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -36,14 +36,7 @@ from jax.sharding import Mesh import numpy as np config.parse_flags_with_absl() - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(2) map, unsafe_map = util.safe_map, map diff --git a/tests/roofline_test.py b/tests/roofline_test.py index e50039471..aec34ff22 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -14,7 +14,6 @@ from __future__ import annotations from functools import partial -import contextlib from absl.testing import absltest from jax.sharding import PartitionSpec as P @@ -28,6 +27,7 @@ from jax.experimental import roofline jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(8) def create_inputs( @@ -45,18 +45,6 @@ def create_inputs( return mesh, tuple(arrays) -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - - -def tearDownModule(): - _exit_stack.close() - - class RooflineTest(jtu.JaxTestCase): def test_scalar_collectives(self): a_spec = P("z", ("x", "y")) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 383746899..25d46c5ad 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib - import jax import jax.numpy as jnp import numpy as np @@ -24,15 +22,7 @@ from jax.experimental.shard_alike import shard_alike from jax.experimental.shard_map import shard_map jax.config.parse_flags_with_absl() - -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class ShardAlikeDownstreamTest(jtu.JaxTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ec846a32a..19cc87088 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -15,7 +15,6 @@ from __future__ import annotations from collections.abc import Callable, Generator, Iterable, Iterator, Sequence -import contextlib from functools import partial import itertools as it import math @@ -53,6 +52,7 @@ from jax.experimental.shard_map import shard_map from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member config.parse_flags_with_absl() +jtu.request_cpu_devices(8) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -70,16 +70,6 @@ def create_inputs(a_sharding, b_sharding): return mesh, m1, m2 -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() - - class ShardMapTest(jtu.JaxTestCase): def test_identity(self):