Add an experimental Cloud TPU presubmit workflow and adapt the TPU pytest script.

Summary of what this patch does (leading text like this is skipped by patch tools):

  * Creates .github/workflows/cloud-tpu-ci-presubmit.yml. The workflow is
    triggered by pull requests targeting main and by manual dispatch, skips
    forks, checks out JAX and XLA, builds jaxlib against the local XLA
    checkout, installs a pre-release libtpu, optionally halts for a remote
    connection, and then runs ci/run_pytest_tpu.sh.
  * Marks ci/run_pytest_tpu.sh as executable (mode 100644 to 100755).
  * In ci/run_pytest_tpu.sh: looks for libtpu.so under dist-packages instead
    of site-packages, groups the shared environment variables in one place,
    scopes JAX_ENABLE_TPU_XDIST and TPU_STDERR_LOG_LEVEL to the single pytest
    invocation that needs each of them, and narrows the test targets to
    tests/pallas/tpu_ops_test.py and tests/pjit_test.py.

NOTE(review): google-ml-infra/actions/ci_connection is referenced at @main
rather than at a pinned commit; confirm this is intended given the pinning
policy cited in the workflow's own comments.

NOTE(review): YAML indentation below was reconstructed using conventional
2-space nesting; verify against the committed workflow file.

diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml
new file mode 100644
index 000000000..bc19a6e85
--- /dev/null
+++ b/.github/workflows/cloud-tpu-ci-presubmit.yml
@@ -0,0 +1,93 @@
+# Cloud TPU CI (presubmit)
+#
+# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
+# tested to get to a stable state before we enable it as a blocking presubmit.
+name: CI - Cloud TPU (presubmit)
+on:
+  workflow_dispatch:
+    inputs:
+      halt-for-connection:
+        description: 'Should this workflow run wait for a remote connection?'
+        type: choice
+        required: true
+        default: 'no'
+        options:
+          - 'yes'
+          - 'no'
+  pull_request:
+    branches:
+      - main
+
+# This should also be set to read-only in the project settings, but it's nice to
+# document and enforce the permissions here.
+permissions:
+  contents: read
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  cloud-tpu-test:
+    if: github.event.repository.fork == false
+    strategy:
+      fail-fast: false  # don't cancel all jobs on failure
+      matrix:
+        tpu: [
+          {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
+        ]
+        python-version: ["3.10"]
+
+    name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
+
+    env:
+      JAXCI_PYTHON: python${{ matrix.python-version }}
+      JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
+
+    runs-on: ${{ matrix.tpu.runner }}
+    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
+
+    timeout-minutes: 60
+
+    defaults:
+      run:
+        shell: bash -ex {0}
+
+    steps:
+      # https://opensource.google/documentation/reference/github/services#actions
+      # mandates using a specific commit for non-Google actions. We use
+      # https://github.com/sethvargo/ratchet to pin specific versions.
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
+      # Checkout XLA at head, if we're building jaxlib at head.
+      - name: Checkout XLA at head
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
+        with:
+          repository: openxla/xla
+          path: xla
+      # We need to mark the GitHub workspace as safe as otherwise git commands will fail.
+      - name: Mark GitHub workspace as safe
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+      - name: Install JAX test requirements
+        run: |
+          $JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
+          $JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
+      - name: Build jaxlib at head with latest XLA
+        run: |
+          # Build and install jaxlib at head
+          $JAXCI_PYTHON build/build.py build --wheels=jaxlib \
+              --python_version=${{ matrix.python-version }} \
+              --bazel_options=--config=rbe_linux_x86_64 \
+              --local_xla_path="$(pwd)/xla" \
+              --verbose
+
+          # Install libtpu
+          $JAXCI_PYTHON -m pip install --pre libtpu \
+              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+      # Halt for testing
+      - name: Wait For Connection
+        uses: google-ml-infra/actions/ci_connection@main
+        with:
+          halt-dispatch-input: ${{ inputs.halt-for-connection }}
+      - name: Install jaxlib wheel and run tests
+        run: ./ci/run_pytest_tpu.sh
diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh
old mode 100644
new mode 100755
index 783d2f9fe..94d1eb987
--- a/ci/run_pytest_tpu.sh
+++ b/ci/run_pytest_tpu.sh
@@ -33,29 +33,28 @@ source ./ci/utilities/install_wheels_locally.sh
 # Set up the build environment.
 source "ci/utilities/setup_build_environment.sh"
 
-export PY_COLORS=1
-export JAX_SKIP_SLOW_TESTS=true
-
 "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
-
 "$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
 "$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'
 "$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
-strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on'
+strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
 "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
 
-echo "Running TPU tests..."
+# Set up all common test environment variables
+export PY_COLORS=1
 export JAX_PLATFORMS=tpu,cpu
-# Run single-accelerator tests in parallel
-export JAX_ENABLE_TPU_XDIST=true
+export JAX_SKIP_SLOW_TESTS=true
+# End of common test environment variable setup
 
-"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
+echo "Running TPU tests..."
+
+# Run single-accelerator tests in parallel
+JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
 --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
---maxfail=20 -m "not multiaccelerator" tests examples
+--maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py
 
 # Run Pallas printing tests, which need to run with I/O capturing disabled.
-export TPU_STDERR_LOG_LEVEL=0
-"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
+TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
 
 # Run multi-accelerator across all chips
-"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
\ No newline at end of file
+"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py
\ No newline at end of file