Make pytest-xdist work on TPU and update Cloud TPU CI.

This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
This commit is contained in:
Skye Wanderman-Milne 2022-11-17 05:33:54 +00:00
parent 3a837c8069
commit 120125f3dd
8 changed files with 70 additions and 2 deletions

View File

@ -53,7 +53,12 @@ jobs:
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
run: python -m pytest --tb=short tests examples
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
-m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python -m pytest -m "multiaccelerator" --tb=short tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ failure() && github.ref_name == 'main' }}

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""pytest configuration"""
import os
import pytest
@ -24,3 +25,35 @@ def add_imports(doctest_namespace):
doctest_namespace["lax"] = jax.lax
doctest_namespace["jnp"] = jax.numpy
doctest_namespace["np"] = numpy
# A pytest hook that runs immediately before test collection (i.e. when pytest
# loads all the test cases to run). When running parallel tests via xdist on
# Cloud TPU, we use this hook to set the env vars needed to run multiple test
# processes across different TPU chips.
#
# It's important that the hook runs before test collection, since jax tests end
# up initializing the TPU runtime on import (e.g. to query supported test
# types). It's also important that the hook gets called by each xdist worker
# process. Luckily each worker does its own test collection.
#
# The pytest_collection hook can be used to overwrite the collection logic, but
# we only use it to set the env vars and fall back to the default collection
# logic by always returning None. See
# https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result
# for details.
#
# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
# in test code to this "magic" hook. TPU tests should not specify more xdist
# workers than the number of TPU chips.
def pytest_collection() -> None:
if not os.environ.get("JAX_ENABLE_TPU_XDIST", None):
return
# When running as an xdist worker, will be something like "gw0"
xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "")
if not xdist_worker_name.startswith("gw"):
return
xdist_worker_number = int(xdist_worker_name[len("gw"):])
os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")

View File

@ -1,5 +1,6 @@
[pytest]
markers =
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
filterwarnings =
error
@ -19,7 +20,7 @@ filterwarnings =
# numpy uses distutils which is deprecated
ignore:The distutils.* is deprecated.*:DeprecationWarning
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
default:Error reading persistent compilation cache entry for 'jit__lambda_'
default:Error writing persistent compilation cache entry for 'jit__lambda_'
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for GlobalDeviceArray."""
import contextlib
import os
import unittest
from absl.testing import absltest
@ -43,6 +44,11 @@ config.parse_flags_with_absl()
prev_xla_flags = None
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
# Run all tests with 8 CPU devices.
def setUpModule():
global prev_xla_flags

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import re
from functools import partial, lru_cache
@ -57,6 +58,11 @@ config.parse_flags_with_absl()
prev_xla_flags = None
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")

View File

@ -14,6 +14,7 @@
from concurrent.futures import ThreadPoolExecutor
import contextlib
from functools import partial
import itertools as it
import gc
@ -55,6 +56,11 @@ config.parse_flags_with_absl()
prev_xla_flags = None
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
def all_bdims(*shapes, pmap):

View File

@ -14,6 +14,7 @@
"""Tests for cross host device transfer."""
from absl.testing import absltest
import contextlib
import unittest
import numpy as np
@ -24,6 +25,10 @@ from jax.config import config
config.parse_flags_with_absl()
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
class RemoteTransferTest(jtu.JaxTestCase):

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import itertools as it
import os
@ -55,6 +56,11 @@ from jax.ad_checkpoint import checkpoint
from jax.config import config
config.parse_flags_with_absl()
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
# Run all tests with 8 CPU devices.
def setUpModule():