mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS. In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads. PiperOrigin-RevId: 713272197
This commit is contained in:
parent
f96339be1e
commit
51b9fe3010
@ -498,29 +498,20 @@ def device_supports_buffer_donation():
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_host_platform_device_count(nr_devices: int):
|
||||
"""Context manager to set host platform device count if not specified by user.
|
||||
def request_cpu_devices(nr_devices: int):
|
||||
"""Requests at least `nr_devices` CPU devices.
|
||||
|
||||
This should only be used by tests at the top level in setUpModule(); it will
|
||||
not work correctly if applied to individual test cases.
|
||||
request_cpu_devices should be called at the top-level of a test module before
|
||||
main() runs.
|
||||
|
||||
It is not guaranteed that the number of CPU devices will be exactly
|
||||
`nr_devices`: it may be more or less, depending on how exactly the test is
|
||||
invoked. Test cases that require a specific number of devices should skip
|
||||
themselves if that number is not met.
|
||||
"""
|
||||
prev_xla_flags = os.getenv("XLA_FLAGS")
|
||||
flags_str = prev_xla_flags or ""
|
||||
# Don't override user-specified device count, or other XLA flags.
|
||||
if "xla_force_host_platform_device_count" not in flags_str:
|
||||
os.environ["XLA_FLAGS"] = (flags_str +
|
||||
f" --xla_force_host_platform_device_count={nr_devices}")
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev_xla_flags is None:
|
||||
del os.environ["XLA_FLAGS"]
|
||||
else:
|
||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||
if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
config.update("jax_num_cpu_devices", nr_devices)
|
||||
|
||||
|
||||
def skip_on_flag(flag_name, skip_value):
|
||||
|
@ -122,6 +122,14 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
|
||||
"inline without async dispatch.",
|
||||
)
|
||||
|
||||
NUM_CPU_DEVICES = config.int_flag(
|
||||
name="jax_num_cpu_devices",
|
||||
default=-1,
|
||||
help="Number of CPU devices to use. If not provided, the value of "
|
||||
"the XLA flag --xla_force_host_platform_device_count is used."
|
||||
" Must be set before JAX is initialized.",
|
||||
)
|
||||
|
||||
|
||||
# Warn the user if they call fork(), because it's not going to go well for them.
|
||||
def _at_fork():
|
||||
@ -249,8 +257,8 @@ def make_cpu_client(
|
||||
if collectives is None:
|
||||
collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
|
||||
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
|
||||
collectives_impl = 'gloo'
|
||||
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
|
||||
collectives_impl = 'gloo'
|
||||
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
|
||||
'deprecated. Please use `jax.config.update('
|
||||
'"jax_cpu_collectives_implementation", "gloo")` instead.',
|
||||
DeprecationWarning,
|
||||
@ -268,12 +276,22 @@ def make_cpu_client(
|
||||
f"{collectives_impl}. Available implementations are "
|
||||
f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
|
||||
|
||||
num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
|
||||
if xla_client._version < 303 and num_devices is not None:
|
||||
xla_flags = os.getenv("XLA_FLAGS") or ""
|
||||
os.environ["XLA_FLAGS"] = (
|
||||
f"{xla_flags} --xla_force_host_platform_device_count={num_devices}"
|
||||
)
|
||||
num_devices = None
|
||||
# TODO(phawkins): pass num_devices directly when version 303 is the minimum.
|
||||
kwargs = {} if num_devices is None else {"num_devices": num_devices}
|
||||
return xla_client.make_cpu_client(
|
||||
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
collectives=collectives,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
"""Tests for serialization and deserialization of GDA."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import math
|
||||
from functools import partial
|
||||
import os
|
||||
@ -36,13 +35,7 @@ import numpy as np
|
||||
import tensorstore as ts
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
|
||||
class CheckpointTest(jtu.JaxTestCase):
|
||||
|
@ -19,7 +19,6 @@
|
||||
|
||||
"""
|
||||
from collections.abc import Sequence
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import logging
|
||||
import re
|
||||
@ -47,16 +46,14 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
# Must come after initializing the flags
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
topology = None
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
global topology
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
with jtu.ignore_warning(message="the imp module is deprecated"):
|
||||
@ -67,8 +64,6 @@ def setUpModule():
|
||||
else:
|
||||
topology = None
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
|
||||
class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
|
@ -43,20 +43,12 @@ from jax._src import array
|
||||
from jax._src import prng
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import pytest
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
|
||||
def create_array(shape, sharding, global_data=None):
|
||||
if global_data is None:
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
from typing import Sequence
|
||||
@ -29,6 +28,7 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
|
||||
def _colocated_cpu_devices(
|
||||
@ -53,18 +53,6 @@ def _colocated_cpu_devices(
|
||||
_count_colocated_python_specialization_cache_miss = jtu.count_events(
|
||||
"colocated_python_func._get_specialized_func")
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
|
||||
def setUpModule():
|
||||
# TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT
|
||||
# prepares CPU devices by its own.
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
|
||||
class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
import contextlib
|
||||
import io
|
||||
import re
|
||||
import textwrap
|
||||
@ -29,6 +28,7 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(2)
|
||||
|
||||
def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
|
||||
fake_stdin = io.StringIO()
|
||||
@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI
|
||||
def _format_multiline(text):
|
||||
return textwrap.dedent(text).lstrip()
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
foo = 2
|
||||
|
||||
class CliDebuggerTest(jtu.JaxTestCase):
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import textwrap
|
||||
import unittest
|
||||
@ -35,19 +34,13 @@ except ModuleNotFoundError:
|
||||
rich = None
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(2)
|
||||
|
||||
debug_print = debugging.debug_print
|
||||
|
||||
def _format_multiline(text):
|
||||
return textwrap.dedent(text).lstrip()
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
class DummyDevice:
|
||||
def __init__(self, platform, id):
|
||||
|
@ -56,14 +56,8 @@ except (ModuleNotFoundError, ImportError):
|
||||
CAN_SERIALIZE = False
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
### Setup for testing lowering with effects
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
@ -34,6 +34,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(2)
|
||||
|
||||
effect_p = core.Primitive('effect')
|
||||
effect_p.multiple_results = True
|
||||
@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
|
||||
mlir.register_lowering(callback_p, callback_effect_lowering)
|
||||
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
|
||||
class JaxprEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_trivial_jaxpr_has_no_effects(self):
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
from functools import partial
|
||||
from absl.testing import absltest
|
||||
@ -28,14 +27,7 @@ from jax._src.util import safe_zip
|
||||
from jax.experimental.compute_on import compute_on
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
|
||||
class LayoutTest(jtu.JaxTestCase):
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
from unittest import SkipTest
|
||||
import tracemalloc as tm
|
||||
|
||||
@ -25,15 +24,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
|
||||
class MultiDeviceTest(jtu.JaxTestCase):
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict, namedtuple
|
||||
import contextlib
|
||||
import re
|
||||
from functools import partial
|
||||
import logging
|
||||
@ -64,14 +63,7 @@ from jax._src.util import curry, unzip2
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
|
||||
dtype=np.float32):
|
||||
|
@ -15,7 +15,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import gc
|
||||
@ -54,15 +53,8 @@ from jax._src.lib import xla_extension
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
|
||||
|
||||
|
@ -36,14 +36,7 @@ from jax.sharding import Mesh
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
jtu.request_cpu_devices(2)
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import contextlib
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax.sharding import PartitionSpec as P
|
||||
@ -28,6 +27,7 @@ from jax.experimental import roofline
|
||||
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
|
||||
def create_inputs(
|
||||
@ -45,18 +45,6 @@ def create_inputs(
|
||||
return mesh, tuple(arrays)
|
||||
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
|
||||
class RooflineTest(jtu.JaxTestCase):
|
||||
def test_scalar_collectives(self):
|
||||
a_spec = P("z", ("x", "y"))
|
||||
|
@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
@ -24,15 +22,7 @@ from jax.experimental.shard_alike import shard_alike
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
|
||||
class ShardAlikeDownstreamTest(jtu.JaxTestCase):
|
||||
|
@ -15,7 +15,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import math
|
||||
@ -53,6 +52,7 @@ from jax.experimental.shard_map import shard_map
|
||||
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
@ -70,16 +70,6 @@ def create_inputs(a_sharding, b_sharding):
|
||||
return mesh, m1, m2
|
||||
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
|
||||
|
||||
class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_identity(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user