[JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.

This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
This commit is contained in:
Peter Hawkins 2025-01-08 06:37:02 -08:00 committed by jax authors
parent f96339be1e
commit 51b9fe3010
18 changed files with 48 additions and 172 deletions

View File

@ -498,29 +498,20 @@ def device_supports_buffer_donation():
)
@contextmanager
def set_host_platform_device_count(nr_devices: int):
"""Context manager to set host platform device count if not specified by user.
def request_cpu_devices(nr_devices: int):
"""Requests at least `nr_devices` CPU devices.
This should only be used by tests at the top level in setUpModule(); it will
not work correctly if applied to individual test cases.
request_cpu_devices should be called at the top-level of a test module before
main() runs.
It is not guaranteed that the number of CPU devices will be exactly
`nr_devices`: it may be more or less, depending on how exactly the test is
invoked. Test cases that require a specific number of devices should skip
themselves if that number is not met.
"""
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
f" --xla_force_host_platform_device_count={nr_devices}")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
try:
yield
finally:
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
xla_bridge.get_backend.cache_clear()
config.update("jax_num_cpu_devices", nr_devices)
def skip_on_flag(flag_name, skip_value):

View File

@ -122,6 +122,14 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
"inline without async dispatch.",
)
NUM_CPU_DEVICES = config.int_flag(
name="jax_num_cpu_devices",
default=-1,
help="Number of CPU devices to use. If not provided, the value of "
"the XLA flag --xla_force_host_platform_device_count is used."
" Must be set before JAX is initialized.",
)
# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
@ -249,8 +257,8 @@ def make_cpu_client(
if collectives is None:
collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
collectives_impl = 'gloo'
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
collectives_impl = 'gloo'
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
'deprecated. Please use `jax.config.update('
'"jax_cpu_collectives_implementation", "gloo")` instead.',
DeprecationWarning,
@ -268,12 +276,22 @@ def make_cpu_client(
f"{collectives_impl}. Available implementations are "
f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
if xla_client._version < 303 and num_devices is not None:
xla_flags = os.getenv("XLA_FLAGS") or ""
os.environ["XLA_FLAGS"] = (
f"{xla_flags} --xla_force_host_platform_device_count={num_devices}"
)
num_devices = None
# TODO(phawkins): pass num_devices directly when version 303 is the minimum.
kwargs = {} if num_devices is None else {"num_devices": num_devices}
return xla_client.make_cpu_client(
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
num_nodes=distributed.global_state.num_processes,
collectives=collectives,
**kwargs,
)

View File

@ -14,7 +14,6 @@
"""Tests for serialization and deserialization of GDA."""
import asyncio
import contextlib
import math
from functools import partial
import os
@ -36,13 +35,7 @@ import numpy as np
import tensorstore as ts
jax.config.parse_flags_with_absl()
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)
class CheckpointTest(jtu.JaxTestCase):

View File

@ -19,7 +19,6 @@
"""
from collections.abc import Sequence
import contextlib
from functools import partial
import logging
import re
@ -47,16 +46,14 @@ import numpy as np
import tensorflow as tf
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util
_exit_stack = contextlib.ExitStack()
topology = None
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
global topology
if jtu.test_device_matches(["tpu"]):
with jtu.ignore_warning(message="the imp module is deprecated"):
@ -67,8 +64,6 @@ def setUpModule():
else:
topology = None
def tearDownModule():
_exit_stack.close()
class ShardingTest(tf_test_util.JaxToTfTestCase):

View File

@ -43,20 +43,12 @@ from jax._src import array
from jax._src import prng
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
def create_array(shape, sharding, global_data=None):
if global_data is None:

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import threading
import time
from typing import Sequence
@ -29,6 +28,7 @@ import jax.numpy as jnp
import numpy as np
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
def _colocated_cpu_devices(
@ -53,18 +53,6 @@ def _colocated_cpu_devices(
_count_colocated_python_specialization_cache_miss = jtu.count_events(
"colocated_python_func._get_specialized_func")
_exit_stack = contextlib.ExitStack()
def setUpModule():
# TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT
# prepares CPU devices by its own.
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
class ColocatedPythonTest(jtu.JaxTestCase):

View File

@ -13,7 +13,6 @@
# limitations under the License.
from collections.abc import Sequence
import contextlib
import io
import re
import textwrap
@ -29,6 +28,7 @@ import jax.numpy as jnp
import numpy as np
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(2)
def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
fake_stdin = io.StringIO()
@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI
def _format_multiline(text):
return textwrap.dedent(text).lstrip()
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
def tearDownModule():
_exit_stack.close()
foo = 2
class CliDebuggerTest(jtu.JaxTestCase):

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import textwrap
import unittest
@ -35,19 +34,13 @@ except ModuleNotFoundError:
rich = None
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(2)
debug_print = debugging.debug_print
def _format_multiline(text):
return textwrap.dedent(text).lstrip()
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
def tearDownModule():
_exit_stack.close()
class DummyDevice:
def __init__(self, platform, id):

View File

@ -56,14 +56,8 @@ except (ModuleNotFoundError, ImportError):
CAN_SERIALIZE = False
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
def tearDownModule():
_exit_stack.close()
### Setup for testing lowering with effects
@dataclasses.dataclass(frozen=True)

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import threading
import unittest
@ -34,6 +34,7 @@ from jax._src.interpreters import partial_eval as pe
import numpy as np
config.parse_flags_with_absl()
jtu.request_cpu_devices(2)
effect_p = core.Primitive('effect')
effect_p.multiple_results = True
@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
mlir.register_lowering(callback_p, callback_effect_lowering)
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
def tearDownModule():
_exit_stack.close()
class JaxprEffectsTest(jtu.JaxTestCase):
def test_trivial_jaxpr_has_no_effects(self):

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
from functools import partial
from absl.testing import absltest
@ -28,14 +27,7 @@ from jax._src.util import safe_zip
from jax.experimental.compute_on import compute_on
config.parse_flags_with_absl()
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)
class LayoutTest(jtu.JaxTestCase):

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from unittest import SkipTest
import tracemalloc as tm
@ -25,15 +24,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src import test_util as jtu
jax.config.parse_flags_with_absl()
# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)
class MultiDeviceTest(jtu.JaxTestCase):

View File

@ -13,7 +13,6 @@
# limitations under the License.
from collections import OrderedDict, namedtuple
import contextlib
import re
from functools import partial
import logging
@ -64,14 +63,7 @@ from jax._src.util import curry, unzip2
config.parse_flags_with_absl()
# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)
def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
dtype=np.float32):

View File

@ -15,7 +15,6 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
import contextlib
from functools import partial
import itertools as it
import gc
@ -54,15 +53,8 @@ from jax._src.lib import xla_extension
from jax._src.util import safe_map, safe_zip
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]

View File

@ -36,14 +36,7 @@ from jax.sharding import Mesh
import numpy as np
config.parse_flags_with_absl()
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(2)
map, unsafe_map = util.safe_map, map

View File

@ -14,7 +14,6 @@
from __future__ import annotations
from functools import partial
import contextlib
from absl.testing import absltest
from jax.sharding import PartitionSpec as P
@ -28,6 +27,7 @@ from jax.experimental import roofline
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
def create_inputs(
@ -45,18 +45,6 @@ def create_inputs(
return mesh, tuple(arrays)
# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
class RooflineTest(jtu.JaxTestCase):
def test_scalar_collectives(self):
a_spec = P("z", ("x", "y"))

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import jax
import jax.numpy as jnp
import numpy as np
@ -24,15 +22,7 @@ from jax.experimental.shard_alike import shard_alike
from jax.experimental.shard_map import shard_map
jax.config.parse_flags_with_absl()
# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)
class ShardAlikeDownstreamTest(jtu.JaxTestCase):

View File

@ -15,7 +15,6 @@
from __future__ import annotations
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
import contextlib
from functools import partial
import itertools as it
import math
@ -53,6 +52,7 @@ from jax.experimental.shard_map import shard_map
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@ -70,16 +70,6 @@ def create_inputs(a_sharding, b_sharding):
return mesh, m1, m2
# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
class ShardMapTest(jtu.JaxTestCase):
def test_identity(self):