Multi-node CI/CD on GPUs using an on-demand cluster and e2e tests using T5X
This commit is contained in:
parent 8da6c89c7b
commit 056702c1cb

.github/workflows/cat_slurm_logs.py (vendored, 45 lines)
@@ -1,45 +0,0 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script used in the nightly-ci-multiprocess-gpu workflow to process logs."""

import argparse
import os
from typing import List

ISSUE_FORMAT = """\
<details><summary>Failure summary {name}</summary>

```
{content}
```

</details>
"""

def main(logfiles: List[str], outfile: str):
  print(f"extracting content of {logfiles}")
  print(f"and writing to {outfile}")
  with open(outfile, 'w') as f:
    for logfile in logfiles:
      content = open(logfile).read()
      f.write(ISSUE_FORMAT.format(name=os.path.basename(logfile), content=content))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("logfiles", nargs="+", help="The path to the input logfiles")
  parser.add_argument("--outfile", help="The path to the parsed output file to be created.",
                      default="parsed_logs.txt")
  args = parser.parse_args()
  main(logfiles=args.logfiles, outfile=args.outfile)
.github/workflows/nightly-ci-multiprocess-gpu.yml (vendored, 181 lines)
@@ -1,91 +1,140 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: Nightly JAX CI on NVIDIA GPUs
# This configures running JAX tests (especially multi-node multi-gpu) against nightly GPU jaxlib builds.
# This is expected to fail frequently, and so we don't run it against every commit and PR in the repository.
# Portions of this adapted from https://github.com/google/jax/blob/main/.github/workflows/upstream-nightly.yaml

# Controls when the workflow will run
on:
  schedule:
    - cron: "0 12 * * *" # Daily at 12:00 UTC
  workflow_dispatch: # allows triggering the workflow run manually
  pull_request: # Automatically trigger on pull requests affecting this file
    branches:
      - main
    paths:
      - '**workflows/nightly-ci-multiprocess-gpu.yml'
jobs:
  jaxlib-nightly:
    runs-on: self-hosted
  tests:
    runs-on: ubuntu-latest
    timeout-minutes: 120
    env:
      # Secrets to connect and authenticate with the Cluster manager web API
      WEB_API_TOKEN: ${{ secrets.NV_CLUSTER_API_TOKEN }}
      WEB_API_URL: ${{ secrets.NV_CLUSTER_API_URL }}

      GITHUB_WORKSPACE_REMOTE: "~/jax_ci_${{ github.run_id }}_${{ github.run_attempt }}/"
      CONFIG: "-F ${{ github.workspace }}/.ssh/config -o UserKnownHostsFile=${{ github.workspace }}/.ssh/known_hosts"
    steps:
      - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
      - name: Launch slurm job and hook output to this shell
        run: |
          export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
          source $JOBSCRIPTSDIR/slurm_utils_common.sh
          sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest_jaxlib_nightly.sub | tee output.log
          sleep 2m
          export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
          export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
          job_wait "${SLURM_JOBID}" & PID=$!
          touch "${SLURM_OUTPUT}"
          echo -e " ---------------------------------------------------\n" \
                  "----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
                  "---------------------------------------------------\n"
          tail --pid="${PID}" -f "${SLURM_OUTPUT}"
          export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
          export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
          export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
          if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
          if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi
      - name: Publish Test Results
        uses: EnricoMi/publish-unit-test-result-action@46ab8d49369d898e381a607119161771bc65c2a6 # ratchet:EnricoMi/publish-unit-test-result-action@v2
        if: always()
      - uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
        with:
          junit_files: "outputs/*.xml"
          python-version: "3.x"

      - name: Create credentials
        run: |

          # Create SSH keys
          mkdir -p ./.ssh && chmod 700 ./.ssh
          ssh-keygen -N '' -f ./.ssh/id_rsa

      - name: Create cluster
        run: |

          # Setup cluster, get username and IP address
          pip install -r "./.github/workflows/slurm_job_scripts/requirements.txt"
          python3 ./.github/workflows/slurm_job_scripts/oci_cluster_manager.py create_cluster --pubkey "$(cat ./.ssh/id_rsa.pub)" &> oci_automation_create.log
          USER=$(tail -n 2 oci_automation_create.log | head -n 1)
          IP=$(tail -n 1 oci_automation_create.log)

          # Hide IP address from logs
          echo "::add-mask::${IP}"

          # Create SSH config
          grep "^${IP} " oci_automation_create.log >> ./.ssh/known_hosts
          echo "Host headnode
          User ${USER}
          HostName ${IP}
          IdentityFile ${GITHUB_WORKSPACE}/.ssh/id_rsa" > ./.ssh/config

      - name: Check SLURM is working
        run: |

          # SSH into the cluster & check SLURM
          ssh ${CONFIG} headnode sinfo

          # Run dummy job
          SRUN="srun --container-name=nvidia --container-image=docker://nvcr.io#nvidia/tensorflow:22.11-tf2-py3 -N 2 -t 15:00 --gpus-per-node=8 --cpus-per-task=8 --ntasks-per-node=8"
          CMD="bash -c 'hostname && nvidia-smi --query-gpu=gpu_name,driver_version --format=csv'"
          ssh ${CONFIG} headnode "${SRUN} ${CMD}"

      - name: Copy workspace
        run: |

          ssh ${CONFIG} headnode "rm -rf ${GITHUB_WORKSPACE_REMOTE} && mkdir -p ${GITHUB_WORKSPACE_REMOTE}"
          scp ${CONFIG} -r ./.github headnode:${GITHUB_WORKSPACE_REMOTE}/.github
          scp ${CONFIG} -r ./tests headnode:${GITHUB_WORKSPACE_REMOTE}/tests
          scp ${CONFIG} -r ./pytest* headnode:${GITHUB_WORKSPACE_REMOTE}/

      - name: T5X end-to-end tests
        timeout-minutes: 25
        run: |

          ENV="GITHUB_WORKSPACE_REMOTE=${GITHUB_WORKSPACE_REMOTE}"
          SALLOC="salloc -N 2 --gpus-per-node=8 --exclusive -t 0:20:00 -p compute"
          CMD="bash ${GITHUB_WORKSPACE_REMOTE}/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub"
          ssh ${CONFIG} headnode "${ENV} ${SALLOC} ${CMD}"

      - name: Gather results
        if: always()
        run: |

          scp ${CONFIG} -r headnode:${GITHUB_WORKSPACE_REMOTE}/outputs ./

      - name: Destroy cluster
        if: always()
        run: |

          pip install -r "./.github/workflows/slurm_job_scripts/requirements.txt"
          python3 ./.github/workflows/slurm_job_scripts/oci_cluster_manager.py destroy_clusters &> ./oci_automation_destroy.log

      - name: Upload run results from all nodes
        uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # ratchet:actions/upload-artifact@v3
        if: always()
        with:
          name: output-from-nodes
          path: "outputs/*.txt"
  jaxlib-release:
    runs-on: self-hosted
    needs: jaxlib-nightly

  report-metrics:
    name: e2e-tests-metrics
    needs: tests
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    steps:
      - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
      - name: Launch slurm job and hook output to this shell
      - uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
        with:
          python-version: "3.x"
      - uses: actions/download-artifact@9782bd6a9848b53b110e712e20e42d89988822b7 # ratchet:actions/download-artifact@v3
        with:
          path: /tmp/workspace/logs
      - name: Parse log output
        run: |
          export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
          source $JOBSCRIPTSDIR/slurm_utils_common.sh
          sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest_jaxlib_release.sub | tee output.log
          sleep 2m
          export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
          export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
          job_wait "${SLURM_JOBID}" & PID=$!
          touch "${SLURM_OUTPUT}"
          echo -e " ---------------------------------------------------\n" \
                  "----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
                  "---------------------------------------------------\n"
          tail --pid="${PID}" -f "${SLURM_OUTPUT}"
          export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
          export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
          export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
          if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
          if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi
      - name: Publish Test Results
        uses: EnricoMi/publish-unit-test-result-action@46ab8d49369d898e381a607119161771bc65c2a6 # ratchet:EnricoMi/publish-unit-test-result-action@v2
        if: always()
        with:
          junit_files: "outputs/*.xml"
      - name: Upload run results from all nodes
        uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # ratchet:actions/upload-artifact@v3
        if: always()
        with:
          name: output-from-nodes
          path: "outputs/*.txt"
          ls /tmp/workspace/logs/output-from-nodes/ && mv /tmp/workspace/logs/output-from-nodes/output*t5x*1-0-0.txt ${GITHUB_WORKSPACE}/output.log
          pip install -r "${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts/requirements.txt"
          python ${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts/extract_e2e_tests_metrics.py --logfile ${GITHUB_WORKSPACE}/output.log --outmd ${GITHUB_WORKSPACE}/report.md --outjson ${GITHUB_WORKSPACE}/report.json --name end-to-end-t5x
          cat report.md >> $GITHUB_STEP_SUMMARY

  report:
    name: report
    needs: [jaxlib-nightly, jaxlib-release]
    needs: tests
    if: |
      failure()
      && github.event_name == 'schedule'
.github/workflows/slurm_job_scripts/extract_e2e_tests_metrics.py (vendored, new file, 61 lines)
@@ -0,0 +1,61 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script used in the nightly-ci-multiprocess-gpu workflow to process logs."""

import argparse
import os
import re
import json
import datetime
import pandas as pd

stats_pat = r".*collection=train .*(timing/seconds=[\d.]+), (timing/seqs=[\d.]+), (timing/seqs_per_second=[\d.]+), (timing/seqs_per_second_per_core=[\d.]+), (timing/steps_per_second=[\d.]+), (timing/target_tokens_per_second=[\d.]+), (timing/target_tokens_per_second_per_core=[\d.]+).*"

def main(logfile: str, outmd: str, outjson: str, name: str):
  print(f"Extracting content of {logfile}")
  print(f"and writing to {outmd} and {outjson}")

  with open(logfile, 'r') as fp:
    lines = fp.read()
    stats = re.findall(stats_pat, lines)

  data_parsed = [
    # Extract `metric` and `value` from `timing/metric=value`
    {re.split('=|/', s)[1]: float(re.split('=|/', s)[2]) for s in stat}
    for stat in stats
  ]
  df = pd.DataFrame(data_parsed).reset_index(drop=True)
  df.to_markdown(outmd, index=False)

  data = {
    'name': name,
    'date': datetime.datetime.now(tz=None).isoformat(),
    'data': data_parsed,
    'github': {k: v for (k, v) in os.environ.items() if k.startswith('GITHUB')}
  }

  with open(outjson, "w") as ofile:
    ofile.write(json.dumps(data, indent=4))

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("--logfile", help="The path to the input logfile")
  parser.add_argument("--outmd", help="The path to the parsed output markdown file to be created.",
                      default="metrics_report.md")
  parser.add_argument("--outjson", help="The path to the parsed output json file to be created.",
                      default="metrics_report.json")
  parser.add_argument("--name", help="Name of the benchmark to be added to the JSON.")
  args = parser.parse_args()
  main(logfile=args.logfile, outmd=args.outmd, outjson=args.outjson, name=args.name)
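
To make the parsing above concrete, here is a small self-contained sketch of how stats_pat behaves; the pattern is abbreviated to three of its seven capture groups and the log line is entirely made up for illustration:

import re

# Abbreviated version of stats_pat (three of the seven capture groups).
stats_pat = (r".*collection=train .*(timing/seconds=[\d.]+), "
             r"(timing/seqs=[\d.]+), (timing/seqs_per_second=[\d.]+).*")

# Hypothetical T5X training log line in the shape the pattern expects.
line = ("collection=train step=100 timing/seconds=12.5, "
        "timing/seqs=2048.0, timing/seqs_per_second=163.84")

for stat in re.findall(stats_pat, line):
    # Each capture looks like "timing/<metric>=<value>"; splitting on '/' and '='
    # yields the metric name and its numeric value, exactly as in the script above.
    print({re.split('=|/', s)[1]: float(re.split('=|/', s)[2]) for s in stat})
    # -> {'seconds': 12.5, 'seqs': 2048.0, 'seqs_per_second': 163.84}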
@@ -1,81 +0,0 @@
#!/bin/bash
#SBATCH -A ci-jax-gpu
#SBATCH -p compute
#SBATCH -N 2                 # number of nodes
#SBATCH -t 00:15:00          # wall time
#SBATCH -J "ci-jax-gpu"      # job name
#SBATCH --exclusive          # exclusive node access
#SBATCH --mem=0              # all mem avail
#SBATCH --mail-type=FAIL     # only send email on failures
#SBATCH --overcommit         # Needed for pytorch

set -x

# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="nvcr.io/nvidian/jax_t5x:cuda11.4-cudnn8.2-ubuntu20.04-manylinux2014-multipython"
CONTAINER_NAME="multinode_ci_test_container"

BASE_WORKSPACE_DIR=$GITHUB_WORKSPACE
WORKSPACE_DIR=/workspace

MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:/$WORKSPACE_DIR"

# Since the docker container doesn't contain MLX drivers for IB, following flags
# are needed to make NCCL work with an ethernet setup
# Note:@sudhakarsingh27 This is very specific, need to abstract this out
EXPORTS="--export=ALL,NCCL_SOCKET_IFNAME=enp45s0f0,NCCL_SOCKET_NTHREADS=2,NCCL_NSOCKS_PERTHREAD=2"
#-------------------------------------------------------------------------------

# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
python3.8 -m pip install --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html \
    && python3.8 -m pip install git+https://github.com/google/jax \
    && python3.8 -m pip install pytest \
    && python3.8 -m pip install pytest-forked \
    && mkdir -p /workspace/outputs/
EOF

# Main pytest command that runs the tests
read -r -d '' cmd <<EOF
date \
    && python3.8 -m pip list | grep jax \
    && python3.8 -m pytest -m SlurmMultiNodeGpuTest --forked -v -s --continue-on-collection-errors \
        --junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
        /workspace/tests/multiprocess_gpu_test.py
EOF

# create run specific output directory for ease of analysis
OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/"
mkdir -p $OUTPUT_DIR

# redirect both stdout and stderr in the same file for ease of analysis
OUTFILE="${OUTPUT_DIR}/output-test-jaxlib-nightly-%j-%n.txt"

# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo $setup_cmd
srun -o $OUTFILE -e $OUTFILE \
    --ntasks-per-node=1 \
    --container-writable \
    --container-image="$CONTAINER" \
    --container-name=$CONTAINER_NAME \
    $MOUNTS \
    $EXPORTS \
    bash -c "${setup_cmd}"

# Barrier command
wait

# Run the actual pytest command
echo $cmd
srun -o $OUTFILE -e $OUTFILE \
    --ntasks-per-node=8 \
    --open-mode=append \
    --container-writable \
    --container-image="$CONTAINER" \
    --container-name=$CONTAINER_NAME \
    $MOUNTS \
    $EXPORTS \
    bash -c "${cmd}"
set +x
@@ -1,80 +0,0 @@
#!/bin/bash
#SBATCH -A ci-jax-gpu
#SBATCH -p compute
#SBATCH -N 2                 # number of nodes
#SBATCH -t 00:15:00          # wall time
#SBATCH -J "ci-jax-gpu"      # job name
#SBATCH --exclusive          # exclusive node access
#SBATCH --mem=0              # all mem avail
#SBATCH --mail-type=FAIL     # only send email on failures
#SBATCH --overcommit         # Needed for pytorch

set -x

# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="nvcr.io/nvidian/jax_t5x:cuda11.4-cudnn8.2-ubuntu20.04-manylinux2014-multipython"
CONTAINER_NAME="multinode_ci_test_container"

BASE_WORKSPACE_DIR=$GITHUB_WORKSPACE
WORKSPACE_DIR=/workspace

MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:/$WORKSPACE_DIR"

# Since the docker container doesn't contain MLX drivers for IB, following flags
# are needed to make NCCL work with an ethernet setup
# Note:@sudhakarsingh27 This is very specific, need to abstract this out
EXPORTS="--export=ALL,NCCL_SOCKET_IFNAME=enp45s0f0,NCCL_SOCKET_NTHREADS=2,NCCL_NSOCKS_PERTHREAD=2"
#-------------------------------------------------------------------------------

# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
python3.8 -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    && python3.8 -m pip install pytest \
    && python3.8 -m pip install pytest-forked \
    && mkdir -p /workspace/outputs/
EOF

# Main pytest command that runs the tests
read -r -d '' cmd <<EOF
date \
    && python3.8 -m pip list | grep jax \
    && python3.8 -m pytest -m SlurmMultiNodeGpuTest --forked -v -s --continue-on-collection-errors \
        --junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
        /workspace/tests/multiprocess_gpu_test.py
EOF

# create run specific output directory for ease of analysis
OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/"
mkdir -p $OUTPUT_DIR

# redirect both stdout and stderr in the same file for ease of analysis
OUTFILE="${OUTPUT_DIR}/output-test-jaxlib-release-%j-%n.txt"

# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo $setup_cmd
srun -o $OUTFILE -e $OUTFILE \
    --ntasks-per-node=1 \
    --container-writable \
    --container-image="$CONTAINER" \
    --container-name=$CONTAINER_NAME \
    $MOUNTS \
    $EXPORTS \
    bash -c "${setup_cmd}"

# Barrier command
wait

# Run the actual pytest command
echo $cmd
srun -o $OUTFILE -e $OUTFILE \
    --ntasks-per-node=8 \
    --open-mode=append \
    --container-writable \
    --container-image="$CONTAINER" \
    --container-name=$CONTAINER_NAME \
    $MOUNTS \
    $EXPORTS \
    bash -c "${cmd}"
set +x
.github/workflows/slurm_job_scripts/oci_cluster_manager.py (vendored, new file, 165 lines)
@@ -0,0 +1,165 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A tool to create and destroy clusters on demand.

To create a cluster:
  export WEB_API_URL=...    # URL of the API endpoint
  export WEB_API_TOKEN=...  # Authentication token
  python3 oci_cluster_manager.py create_cluster --pubkey "$(cat ~/.ssh/id_rsa.pub)"
This will first try to find an existing running cluster, and otherwise attempt to create one in any region.
When successful, the output will contain the headnode hostkeys, the username and IP address of the cluster, or FAILED.

To destroy all previously created clusters:
  export WEB_API_URL=...    # URL of the API endpoint
  export WEB_API_TOKEN=...  # Authentication token
  python3 oci_cluster_manager.py destroy_clusters

This script is used to create and destroy clusters on demand.
A few caveats should be noted:
- Depending on resource availability, it might not be possible to create a cluster.
  In that case, the script will eventually fail.
- Creating a cluster takes time (30 to 60 mins).
  Similarly, destroying a cluster also takes time.
  Users should not attempt to concurrently create clusters.
  As a rule of thumb, this script should only be used at most once every ~12 hours.
- In case a pull request indirectly calls this script, users should take care to ensure
  no other pipeline is attempting to create a cluster at the same time and within a ~3h time window.
- Clusters are automatically destroyed 2h after being created, regardless of whether
  `destroy_all` is called or not.
"""

import os
import argparse
import time
import logging
import sys
import requests

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

_API_URL = os.environ['WEB_API_URL']
_API_AUTH = requests.auth.HTTPBasicAuth('login', os.environ['WEB_API_TOKEN'])

_SLEEP_TIME_SECONDS = int(os.environ.get("SLEEP_TIME_SECONDS", default="30"))
_RETRY_PER_REGION = int(os.environ.get("RETRY_PER_REGION", default="3"))
_REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS", default=120))

def get_regions():
  return requests.post(_API_URL, json={'name':'list_regions'},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['regions']

def find_existing_cluster():
  return requests.post(_API_URL, json={'name':'find_cluster'},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['region']

def get_cluster_ip(region):
  return requests.post(_API_URL, json={'name':'get_cluster_ip', 'region':region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['cluster_ip']

def get_cluster_username(region):
  return requests.post(_API_URL, json={'name':'get_cluster_username', 'region':region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['cluster_username']

def create_cluster(region):
  logging.debug(requests.post(_API_URL, json={'name':'create_cluster', 'region':region},
                              auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS))

def add_pubkey(region, pubkey):
  logging.debug(requests.post(_API_URL, json={'name':'add_pubkey', 'region':region, 'pubkey':pubkey},
                              auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS))

def get_cluster_hostkeys(region):
  return requests.post(_API_URL, json={'name':'get_cluster_hostkeys', 'region':region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['cluster_hostkeys']

def get_status(region):
  return requests.post(_API_URL, json={'name':'get_status', 'region': region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['status']

def destroy_all_clusters():
  logging.debug(requests.post(_API_URL, json={'name':'destroy_all_clusters'},
                              auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS))

def main():

  command_choices = ["create_cluster", "destroy_clusters"]
  parser = argparse.ArgumentParser(description="Creates and destroys compute clusters on-demand.")
  parser.add_argument("command", choices=command_choices,
                      help="""create_cluster will first try to find an existing running cluster,
                      and otherwise attempt to create one in any region.
                      When successful, the output will contain the headnode hostkeys, the username and IP address of the cluster, or FAILED.
                      destroy_clusters will destroy all existing clusters.""")
  parser.add_argument("--pubkey", default=None,
                      help='public key to upload to the cluster (relevant only with `create_cluster`).')

  args = parser.parse_args()

  if args.command == "create_cluster":
    found_region = find_existing_cluster()
    logging.info(f"Found cluster {found_region}")

    if found_region is None:
      logging.info("Could not find existing cluster. Trying to create one.")
      regions = get_regions()
      logging.info(f"Regions considered: {regions}")

      for region in regions * _RETRY_PER_REGION:
        logging.info(f"Trying region {region}")
        create_cluster(region)
        status = get_status(region)

        while status == 'WAIT':
          logging.info(f"Waiting {_SLEEP_TIME_SECONDS} seconds...")
          time.sleep(_SLEEP_TIME_SECONDS)
          status = get_status(region)

        if status == 'SUCCEEDED':
          logging.info("Successfully allocated cluster")
          found_region = region
          break

        else:
          logging.info("Moving to next region")
          continue

    else:
      logging.info(f"Found existing cluster in {found_region}")

    if found_region is not None:
      logging.info(f"Found cluster in {found_region}")
      logging.info(f"Adding pubkey {args.pubkey} to cluster")
      add_pubkey(found_region, args.pubkey)
      logging.info("Fetching host keys, username and IP address")
      ip = get_cluster_ip(found_region)
      username = get_cluster_username(found_region)
      hostkeys = get_cluster_hostkeys(found_region)
      print(hostkeys)
      print(username)
      print(ip)

    else:
      logging.info("Failed to allocate cluster")
      sys.exit(1)

  elif args.command == "destroy_clusters":
    logging.info("Destroying all")
    destroy_all_clusters()

  else:
    raise ValueError(f"Wrong `command` argument. Got {args.command}. Valid choices are {command_choices}")

if __name__ == "__main__":
  main()
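
Two details of main() above are easy to miss. First, the username and IP address are the last two values printed, which is what the workflow's "Create cluster" step relies on when it parses the log with "tail -n 2 ... | head -n 1" and "tail -n 1". Second, multiplying the region list by _RETRY_PER_REGION cycles through every region before retrying any single one; a tiny sketch with made-up region names:

# Hypothetical region names, purely to show the retry order produced by list repetition.
regions = ["region-a", "region-b"]
_RETRY_PER_REGION = 3

# Cycles through all regions before repeating any single one:
# ['region-a', 'region-b', 'region-a', 'region-b', 'region-a', 'region-b']
print(regions * _RETRY_PER_REGION)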
.github/workflows/slurm_job_scripts/requirements.txt (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
requests
pandas
tabulate
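
A likely reason tabulate is listed here: pandas' DataFrame.to_markdown, which extract_e2e_tests_metrics.py uses to write the report, delegates table rendering to the tabulate package. A minimal sketch with made-up numbers:

import pandas as pd

# to_markdown() raises ImportError unless the optional `tabulate` dependency is installed.
df = pd.DataFrame([{"seconds": 12.5, "seqs_per_second": 163.84}])
print(df.to_markdown(index=False))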
.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub (vendored, new file, 96 lines)
@@ -0,0 +1,96 @@
#!/bin/bash

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -x

echo "GITHUB_WORKSPACE_REMOTE=${GITHUB_WORKSPACE_REMOTE}"
TEST_NAME="e2e-t5x"

# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="docker://nvcr.io#nvidia/tensorflow:22.11-tf2-py3"
CONTAINER_NAME="multinode_ci_test_container_${TEST_NAME}_$RANDOM"

# create run specific output directory for ease of analysis
BASE_OUTPUT_DIR="${GITHUB_WORKSPACE_REMOTE}/outputs/"
OUTFILE="${BASE_OUTPUT_DIR}/output-test-${TEST_NAME}-%J-%n-%t.txt"
mkdir -p "$BASE_OUTPUT_DIR"

# Default env variables for paths required by t5x training scripts
OUTPUT_DIR=/outputs
E2E_TESTS_WORKSPACE_DIR=/localdir/e2e_tests_workspace
T5X_DIR="${E2E_TESTS_WORKSPACE_DIR}/t5x"
TFDS_DATA_DIR="${E2E_TESTS_WORKSPACE_DIR}/datasets"

MOUNTS="--container-mounts=$BASE_OUTPUT_DIR:$OUTPUT_DIR"
EXPORTS="--export=ALL,PYTHONPATH=${T5X_DIR}"
#-------------------------------------------------------------------------------

# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
rm -rf ${E2E_TESTS_WORKSPACE_DIR}/* \
    && mkdir -p ${E2E_TESTS_WORKSPACE_DIR} \
    && mkdir -p ${TFDS_DATA_DIR} \
    && python3.8 -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    && git clone https://github.com/google-research/t5x.git ${T5X_DIR} \
    && python3.8 -m pip install ${T5X_DIR} \
    && python3.8 -m pip install ${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config \
    && hostname > ${E2E_TESTS_WORKSPACE_DIR}/hostname.txt
EOF

# Main pytest command that runs the tests
read -r -d '' cmd <<EOF
date \
    && python3.8 -m pip list | grep jax \
    && python3.8 ${T5X_DIR}/t5x/train.py \
        --gin_file="${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/small_pretrain_dummy_wikipedia.gin" \
        --gin.MODEL_DIR=\"${OUTPUT_DIR}/model_dir\" \
        --gin.network.T5Config.dtype=\"bfloat16\" \
        --gin.TRAIN_STEPS=100 \
        --gin.CheckpointConfig.save=None \
        --multiprocess_gpu
EOF

# Count errors
errors=0 && trap errors=1 ERR

# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo "$setup_cmd"
srun -o "$OUTFILE" -e "$OUTFILE" \
    -t 00:10:00 \
    --ntasks-per-node=1 \
    --container-writable \
    --container-image="$CONTAINER" \
    --container-name="$CONTAINER_NAME" \
    "$MOUNTS" \
    "$EXPORTS" \
    timeout -v -k 0.5m 9m bash -c "${setup_cmd}"

# Run the actual pytest command
echo "$cmd"
srun -o "$OUTFILE" -e "$OUTFILE" \
    -t 00:15:00 \
    --ntasks-per-node=8 \
    --open-mode=append \
    --container-writable \
    --container-image="$CONTAINER" \
    --container-name="$CONTAINER_NAME" \
    "$MOUNTS" \
    "$EXPORTS" \
    timeout -v -k 0.5m 14m bash -c "${cmd}"

test $errors = 0
@@ -1,100 +0,0 @@
#! /bin/bash

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# These utility functions are used to monitor SLURM multi-node jobs

job_exit_code() {
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  sacct -j "${JOBID}" -n --format=exitcode | sort -r -u | head -1 | cut -f 1 -d":" | sed 's/ //g'

  exit ${PIPESTATUS[0]}
}

job_state(){
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  sacct -j "${JOBID}" --format State --parsable2 --noheader |& head -n 1

  exit ${PIPESTATUS[0]}
}

job_nodes(){
  set -euo pipefail
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  sacct -j "${JOBID}" -X -n --format=nodelist%400 | sed 's/ //g'
}

job_time(){
  set -euo pipefail
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  ## Note: using export so that this line doesn't cause the script to immediately exit if the subshell failed when running under set -e
  export WALLTIME=$(sacct -j "${JOBID}" --format ElapsedRaw --parsable2 --noheader | head -n 1)

  echo ${WALLTIME:-unknown}
}

job_wait(){
  set -euo pipefail
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  echo "checking for jobid $1"
  JOBID="$1"

  while true; do
    export STATE=$(job_state "${JOBID}")
    case "${STATE}" in
      PENDING|RUNNING|REQUEUED)
        sleep 15s
        ;;
      *)
        sleep 30s
        echo "Exiting with SLURM job status '${STATE}'"
        exit 0
        ;;
    esac
  done
}