mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #13317 from google:xdist_tpu
PiperOrigin-RevId: 490366370
This commit is contained in:
commit
dd902fde21
7
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
7
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
@ -53,7 +53,12 @@ jobs:
|
||||
- name: Run tests
|
||||
env:
|
||||
JAX_PLATFORMS: tpu,cpu
|
||||
run: python -m pytest --tb=short tests examples
|
||||
run: |
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
|
||||
-m "not multiaccelerator" tests examples
|
||||
# Run multi-accelerator across all chips
|
||||
python -m pytest -m "multiaccelerator" --tb=short tests
|
||||
- name: Send chat on failure
|
||||
# Don't notify when testing the workflow from a branch.
|
||||
if: ${{ failure() && github.ref_name == 'main' }}
|
||||
|
33
conftest.py
33
conftest.py
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
"""pytest configuration"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@ -24,3 +25,35 @@ def add_imports(doctest_namespace):
|
||||
doctest_namespace["lax"] = jax.lax
|
||||
doctest_namespace["jnp"] = jax.numpy
|
||||
doctest_namespace["np"] = numpy
|
||||
|
||||
|
||||
# A pytest hook that runs immediately before test collection (i.e. when pytest
|
||||
# loads all the test cases to run). When running parallel tests via xdist on
|
||||
# Cloud TPU, we use this hook to set the env vars needed to run multiple test
|
||||
# processes across different TPU chips.
|
||||
#
|
||||
# It's important that the hook runs before test collection, since jax tests end
|
||||
# up initializing the TPU runtime on import (e.g. to query supported test
|
||||
# types). It's also important that the hook gets called by each xdist worker
|
||||
# process. Luckily each worker does its own test collection.
|
||||
#
|
||||
# The pytest_collection hook can be used to overwrite the collection logic, but
|
||||
# we only use it to set the env vars and fall back to the default collection
|
||||
# logic by always returning None. See
|
||||
# https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result
|
||||
# for details.
|
||||
#
|
||||
# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
|
||||
# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
|
||||
# in test code to this "magic" hook. TPU tests should not specify more xdist
|
||||
# workers than the number of TPU chips.
|
||||
def pytest_collection() -> None:
|
||||
if not os.environ.get("JAX_ENABLE_TPU_XDIST", None):
|
||||
return
|
||||
# When running as an xdist worker, will be something like "gw0"
|
||||
xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "")
|
||||
if not xdist_worker_name.startswith("gw"):
|
||||
return
|
||||
xdist_worker_number = int(xdist_worker_name[len("gw"):])
|
||||
os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
|
||||
os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")
|
||||
|
@ -1,5 +1,6 @@
|
||||
[pytest]
|
||||
markers =
|
||||
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
|
||||
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
|
||||
filterwarnings =
|
||||
error
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
"""Tests for GlobalDeviceArray."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
@ -43,6 +44,11 @@ config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import pytest
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
def setUpModule():
|
||||
global prev_xla_flags
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
from functools import partial, lru_cache
|
||||
@ -57,6 +58,11 @@ config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import pytest
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
|
||||
def setUpModule():
|
||||
global prev_xla_flags
|
||||
prev_xla_flags = os.getenv("XLA_FLAGS")
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import gc
|
||||
@ -55,6 +56,11 @@ config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import pytest
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
|
||||
|
||||
def all_bdims(*shapes, pmap):
|
||||
|
@ -14,6 +14,7 @@
|
||||
"""Tests for cross host device transfer."""
|
||||
|
||||
from absl.testing import absltest
|
||||
import contextlib
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
@ -24,6 +25,10 @@ from jax.config import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import pytest
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
|
||||
class RemoteTransferTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools as it
|
||||
import os
|
||||
@ -55,6 +56,11 @@ from jax.ad_checkpoint import checkpoint
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import pytest
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
|
||||
# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
|
||||
# Run all tests with 8 CPU devices.
|
||||
def setUpModule():
|
||||
|
Loading…
x
Reference in New Issue
Block a user