Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can understand (if pytest is installed). By running each single-device test on just one TPU chip, so tests run in parallel across chips, the test suite goes from 1h 45m to 35m (both timings include the slow tests). I tried bazel first, since it already supports parallel execution across TPU cores, but somehow it still took 2h 20m, and I'm not sure why it's so slow. It appears that bazel keeps creating new test processes over time, whereas pytest reuses the initially specified number of processes, and since starting and stopping the TPU runtime takes a few seconds, that may add up. It also appears that single-process bazel is slower than single-process pytest, which I haven't looked into yet.
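For reference, here is a minimal sketch of the marking idiom this change adds to the multiaccelerator test modules (the module name is illustrative; the guarded import mirrors the test-file hunks below, so the files still run under plain absltest when pytest is absent):

# some_multiaccelerator_test.py -- illustrative name, not part of this commit
import contextlib

# Mark every test in this module as "multiaccelerator", but only if pytest is
# installed; without pytest the assignment is skipped and nothing breaks.
with contextlib.suppress(ImportError):
  import pytest
  pytestmark = pytest.mark.multiaccelerator

With the marker registered in pytest's config, the CI job below deselects these modules for the parallel single-chip pass via -m "not multiaccelerator" and runs them separately across all chips.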
This commit is contained in:
parent 3a837c8069
commit 120125f3dd

.github/workflows/cloud-tpu-ci-nightly.yml (7 changed lines, vendored)
@@ -53,7 +53,12 @@ jobs:
       - name: Run tests
         env:
           JAX_PLATFORMS: tpu,cpu
-        run: python -m pytest --tb=short tests examples
+        run: |
+          # Run single-accelerator tests in parallel
+          JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
+            -m "not multiaccelerator" tests examples
+          # Run multi-accelerator across all chips
+          python -m pytest -m "multiaccelerator" --tb=short tests
       - name: Send chat on failure
         # Don't notify when testing the workflow from a branch.
         if: ${{ failure() && github.ref_name == 'main' }}
conftest.py (33 changed lines)
@@ -13,6 +13,7 @@
 # limitations under the License.
 """pytest configuration"""
 
+import os
 import pytest
 
 
@@ -24,3 +25,35 @@ def add_imports(doctest_namespace):
   doctest_namespace["lax"] = jax.lax
   doctest_namespace["jnp"] = jax.numpy
   doctest_namespace["np"] = numpy
+
+
+# A pytest hook that runs immediately before test collection (i.e. when pytest
+# loads all the test cases to run). When running parallel tests via xdist on
+# Cloud TPU, we use this hook to set the env vars needed to run multiple test
+# processes across different TPU chips.
+#
+# It's important that the hook runs before test collection, since jax tests end
+# up initializing the TPU runtime on import (e.g. to query supported test
+# types). It's also important that the hook gets called by each xdist worker
+# process. Luckily each worker does its own test collection.
+#
+# The pytest_collection hook can be used to overwrite the collection logic, but
+# we only use it to set the env vars and fall back to the default collection
+# logic by always returning None. See
+# https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result
+# for details.
+#
+# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
+# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
+# in test code to this "magic" hook. TPU tests should not specify more xdist
+# workers than the number of TPU chips.
+def pytest_collection() -> None:
+  if not os.environ.get("JAX_ENABLE_TPU_XDIST", None):
+    return
+  # When running as an xdist worker, will be something like "gw0"
+  xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "")
+  if not xdist_worker_name.startswith("gw"):
+    return
+  xdist_worker_number = int(xdist_worker_name[len("gw"):])
+  os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
+  os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")
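The hook above only sets environment variables, so it is easy to sanity-check. A throwaway test like the following (hypothetical file, not part of this commit) can be run with JAX_ENABLE_TPU_XDIST=true and pytest -n 4; it assumes TPU_VISIBLE_CHIPS was not already set in the environment, since the hook uses setdefault:

# tpu_xdist_sanity_test.py -- hypothetical sanity check for the hook above
import os

def test_each_worker_is_pinned_to_one_chip():
  worker = os.environ.get("PYTEST_XDIST_WORKER", "")
  if not worker.startswith("gw"):
    return  # not running under xdist, nothing to verify
  # pytest_collection() should have pinned worker gwN to chip N.
  assert os.environ.get("TPU_VISIBLE_CHIPS") == worker[len("gw"):]
  assert os.environ.get("ALLOW_MULTIPLE_LIBTPU_LOAD") == "true"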
@@ -1,5 +1,6 @@
 [pytest]
 markers =
+  multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
   SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
 filterwarnings =
   error
@@ -19,7 +20,7 @@ filterwarnings =
   # numpy uses distutils which is deprecated
   ignore:The distutils.* is deprecated.*:DeprecationWarning
   ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
   # Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
   default:Error reading persistent compilation cache entry for 'jit__lambda_'
   default:Error writing persistent compilation cache entry for 'jit__lambda_'
 doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
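Besides the module-level pytestmark used in the test-file hunks below, the registered marker can also tag individual tests, which the same -m filters then select or skip; a small illustrative example (the test name is made up):

import pytest

@pytest.mark.multiaccelerator
def test_collective_across_all_chips():  # illustrative name
  ...

pytest -m "multiaccelerator" runs only tests tagged this way, and -m "not multiaccelerator" excludes them, which is exactly how the workflow above splits the two passes.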
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for GlobalDeviceArray."""
 
+import contextlib
 import os
 import unittest
 from absl.testing import absltest
@@ -43,6 +44,11 @@ config.parse_flags_with_absl()
 
 prev_xla_flags = None
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # Run all tests with 8 CPU devices.
 def setUpModule():
   global prev_xla_flags
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import os
 import re
 from functools import partial, lru_cache
@@ -57,6 +58,11 @@ config.parse_flags_with_absl()
 
 prev_xla_flags = None
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 def setUpModule():
   global prev_xla_flags
   prev_xla_flags = os.getenv("XLA_FLAGS")
@@ -14,6 +14,7 @@
 
 
 from concurrent.futures import ThreadPoolExecutor
+import contextlib
 from functools import partial
 import itertools as it
 import gc
@@ -55,6 +56,11 @@ config.parse_flags_with_absl()
 
 prev_xla_flags = None
 
+
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
 compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
 
 def all_bdims(*shapes, pmap):
@@ -14,6 +14,7 @@
 """Tests for cross host device transfer."""
 
 from absl.testing import absltest
+import contextlib
 import unittest
 import numpy as np
 
@@ -24,6 +25,10 @@ from jax.config import config
 
 config.parse_flags_with_absl()
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
 
 class RemoteTransferTest(jtu.JaxTestCase):
 
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import functools
 import itertools as it
 import os
@@ -55,6 +56,11 @@ from jax.ad_checkpoint import checkpoint
 from jax.config import config
 config.parse_flags_with_absl()
 
+with contextlib.suppress(ImportError):
+  import pytest
+  pytestmark = pytest.mark.multiaccelerator
+
+
 # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
 # Run all tests with 8 CPU devices.
 def setUpModule():