Merge pull request #13317 from google:xdist_tpu

PiperOrigin-RevId: 490366370
This commit is contained in:
jax authors 2022-11-22 16:40:00 -08:00
commit dd902fde21
8 changed files with 69 additions and 1 deletions

View File

@ -53,7 +53,12 @@ jobs:
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
run: python -m pytest --tb=short tests examples
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
-m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python -m pytest -m "multiaccelerator" --tb=short tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ failure() && github.ref_name == 'main' }}

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""pytest configuration"""
import os
import pytest
@ -24,3 +25,35 @@ def add_imports(doctest_namespace):
doctest_namespace["lax"] = jax.lax
doctest_namespace["jnp"] = jax.numpy
doctest_namespace["np"] = numpy
# A pytest hook that runs immediately before test collection (i.e. when pytest
# loads all the test cases to run). When running parallel tests via xdist on
# Cloud TPU, we use this hook to set the env vars needed to run multiple test
# processes across different TPU chips.
#
# It's important that the hook runs before test collection, since jax tests end
# up initializing the TPU runtime on import (e.g. to query supported test
# types). It's also important that the hook gets called by each xdist worker
# process. Luckily each worker does its own test collection.
#
# The pytest_collection hook can be used to overwrite the collection logic, but
# we only use it to set the env vars and fall back to the default collection
# logic by always returning None. See
# https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result
# for details.
#
# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
# in test code to this "magic" hook. TPU tests should not specify more xdist
# workers than the number of TPU chips.
def pytest_collection() -> None:
if not os.environ.get("JAX_ENABLE_TPU_XDIST", None):
return
# When running as an xdist worker, will be something like "gw0"
xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "")
if not xdist_worker_name.startswith("gw"):
return
xdist_worker_number = int(xdist_worker_name[len("gw"):])
os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")

View File

@ -1,5 +1,6 @@
[pytest]
markers =
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
filterwarnings =
error

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for GlobalDeviceArray."""
import contextlib
import os
import unittest
from absl.testing import absltest
@ -43,6 +44,11 @@ config.parse_flags_with_absl()
prev_xla_flags = None
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
# Run all tests with 8 CPU devices.
def setUpModule():
global prev_xla_flags

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import re
from functools import partial, lru_cache
@ -57,6 +58,11 @@ config.parse_flags_with_absl()
prev_xla_flags = None
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")

View File

@ -14,6 +14,7 @@
from concurrent.futures import ThreadPoolExecutor
import contextlib
from functools import partial
import itertools as it
import gc
@ -55,6 +56,11 @@ config.parse_flags_with_absl()
prev_xla_flags = None
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
def all_bdims(*shapes, pmap):

View File

@ -14,6 +14,7 @@
"""Tests for cross host device transfer."""
from absl.testing import absltest
import contextlib
import unittest
import numpy as np
@ -24,6 +25,10 @@ from jax.config import config
config.parse_flags_with_absl()
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
class RemoteTransferTest(jtu.JaxTestCase):

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import itertools as it
import os
@ -55,6 +56,11 @@ from jax.ad_checkpoint import checkpoint
from jax.config import config
config.parse_flags_with_absl()
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
# Run all tests with 8 CPU devices.
def setUpModule():