Multi-node CI/CD on GPUs using an on-demand cluster and e2e tests using T5X

This commit is contained in:
Leopold Cambier 2022-09-26 16:47:38 -07:00
parent 8da6c89c7b
commit 056702c1cb
9 changed files with 440 additions and 372 deletions

View File

@ -1,45 +0,0 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script used in the nightly-ci-multiprocess-gpu workflow to process logs."""
import argparse
import os
from typing import List
ISSUE_FORMAT = """\
<details><summary>Failure summary {name}</summary>
```
{content}
```
</details>
"""
def main(logfiles: List[str], outfile: str):
  print(f"extracting content of {logfiles}")
  print(f"and writing to {outfile}")
  with open(outfile, 'w') as f:
    for logfile in logfiles:
      content = open(logfile).read()
      f.write(ISSUE_FORMAT.format(name=os.path.basename(logfile), content=content))

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("logfiles", nargs="+", help="The path to the input logfiles")
  parser.add_argument("--outfile", help="The path to the parsed output file to be created.",
                      default="parsed_logs.txt")
  args = parser.parse_args()
  main(logfiles=args.logfiles, outfile=args.outfile)

View File

@ -1,91 +1,140 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Nightly JAX CI on NVIDIA GPUs
# This configures running JAX tests (especially multi-node multi-gpu) against nightly GPU jaxlib builds.
# This is expected to fail frequently, and so we don't run it against every commit and PR in the repository.
# Portions of this were adapted from https://github.com/google/jax/blob/main/.github/workflows/upstream-nightly.yaml
# Controls when the workflow will run
on:
schedule:
- cron: "0 12 * * *" # Daily at 12:00 UTC
workflow_dispatch: # allows triggering the workflow run manually
pull_request: # Automatically trigger on pull requests affecting this file
branches:
- main
paths:
- '**workflows/nightly-ci-multiprocess-gpu.yml'
jobs:
jaxlib-nightly:
runs-on: self-hosted
tests:
runs-on: ubuntu-latest
timeout-minutes: 120
env:
# Secrets to connect and authenticate with the Cluster manager web API
WEB_API_TOKEN: ${{ secrets.NV_CLUSTER_API_TOKEN }}
WEB_API_URL: ${{ secrets.NV_CLUSTER_API_URL }}
GITHUB_WORKSPACE_REMOTE: "~/jax_ci_${{ github.run_id }}_${{ github.run_attempt }}/"
CONFIG: "-F ${{ github.workspace }}/.ssh/config -o UserKnownHostsFile=${{ github.workspace }}/.ssh/known_hosts"
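# Note: CONFIG bundles the ssh options (-F pointing at the generated config plus a pinned known_hosts file)
# that every ssh/scp invocation below uses to reach the on-demand cluster's headnode.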
steps:
- uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
- name: Launch slurm job and hook output to this shell
run: |
export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
source $JOBSCRIPTSDIR/slurm_utils_common.sh
sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest_jaxlib_nightly.sub | tee output.log
sleep 2m
export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
job_wait "${SLURM_JOBID}" & PID=$!
touch "${SLURM_OUTPUT}"
echo -e " ---------------------------------------------------\n" \
"----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
"---------------------------------------------------\n"
tail --pid="${PID}" -f "${SLURM_OUTPUT}"
export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi
- name: Publish Test Results
uses: EnricoMi/publish-unit-test-result-action@46ab8d49369d898e381a607119161771bc65c2a6 # ratchet:EnricoMi/publish-unit-test-result-action@v2
if: always()
with:
junit_files: "outputs/*.xml"
- uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
with:
python-version: "3.x"
- name: Create credentials
run: |
# Create SSH keys
mkdir -p ./.ssh && chmod 700 ./.ssh
ssh-keygen -N '' -f ./.ssh/id_rsa
- name: Create cluster
run: |
# Setup cluster, get username and IP address
pip install -r "./.github/workflows/slurm_job_scripts/requirements.txt"
python3 ./.github/workflows/slurm_job_scripts/oci_cluster_manager.py create_cluster --pubkey "$(cat ./.ssh/id_rsa.pub)" &> oci_automation_create.log
USER=$(tail -n 2 oci_automation_create.log | head -n 1)
IP=$(tail -n 1 oci_automation_create.log)
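# Note: oci_cluster_manager.py prints the headnode hostkeys, then the username, and finally the
# IP address as its last lines of output, hence the tail/head parsing above.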
# Hide IP address from logs
echo "::add-mask::${IP}"
# Create SSH config
grep "^${IP} " oci_automation_create.log >> ./.ssh/known_hosts
echo "Host headnode
User ${USER}
HostName ${IP}
IdentityFile ${GITHUB_WORKSPACE}/.ssh/id_rsa" > ./.ssh/config
- name: Check SLURM is working
run: |
# SSH into the cluster & check SLURM
ssh ${CONFIG} headnode sinfo
# Run dummy job
SRUN="srun --container-name=nvidia --container-image=docker://nvcr.io#nvidia/tensorflow:22.11-tf2-py3 -N 2 -t 15:00 --gpus-per-node=8 --cpus-per-task=8 --ntasks-per-node=8"
CMD="bash -c 'hostname && nvidia-smi --query-gpu=gpu_name,driver_version --format=csv'"
ssh ${CONFIG} headnode "${SRUN} ${CMD}"
- name: Copy workspace
run: |
ssh ${CONFIG} headnode "rm -rf ${GITHUB_WORKSPACE_REMOTE} && mkdir -p ${GITHUB_WORKSPACE_REMOTE}"
scp ${CONFIG} -r ./.github headnode:${GITHUB_WORKSPACE_REMOTE}/.github
scp ${CONFIG} -r ./tests headnode:${GITHUB_WORKSPACE_REMOTE}/tests
scp ${CONFIG} -r ./pytest* headnode:${GITHUB_WORKSPACE_REMOTE}/
- name: T5X end-to-end tests
timeout-minutes: 25
run: |
ENV="GITHUB_WORKSPACE_REMOTE=${GITHUB_WORKSPACE_REMOTE}"
SALLOC="salloc -N 2 --gpus-per-node=8 --exclusive -t 0:20:00 -p compute"
CMD="bash ${GITHUB_WORKSPACE_REMOTE}/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub"
ssh ${CONFIG} headnode "${ENV} ${SALLOC} ${CMD}"
- name: Gather results
if: always()
run: |
scp ${CONFIG} -r headnode:${GITHUB_WORKSPACE_REMOTE}/outputs ./
- name: Destroy cluster
if: always()
run: |
pip install -r "./.github/workflows/slurm_job_scripts/requirements.txt"
python3 ./.github/workflows/slurm_job_scripts/oci_cluster_manager.py destroy_clusters &> ./oci_automation_destroy.log
- name: Upload run results from all nodes
uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # ratchet:actions/upload-artifact@v3
if: always()
with:
name: output-from-nodes
path: "outputs/*.txt"
jaxlib-release:
runs-on: self-hosted
needs: jaxlib-nightly
report-metrics:
name: e2e-tests-metrics
needs: tests
runs-on: ubuntu-latest
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
- uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
with:
python-version: "3.x"
- uses: actions/download-artifact@9782bd6a9848b53b110e712e20e42d89988822b7 # ratchet:actions/download-artifact@v3
with:
path: /tmp/workspace/logs
- name: Parse log output
run: |
ls /tmp/workspace/logs/output-from-nodes/ && mv /tmp/workspace/logs/output-from-nodes/output*t5x*1-0-0.txt ${GITHUB_WORKSPACE}/output.log
pip install -r "${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts/requirements.txt"
python ${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts/extract_e2e_tests_metrics.py --logfile ${GITHUB_WORKSPACE}/output.log --outmd ${GITHUB_WORKSPACE}/report.md --outjson ${GITHUB_WORKSPACE}/report.json --name end-to-end-t5x
cat report.md >> $GITHUB_STEP_SUMMARY
- name: Launch slurm job and hook output to this shell
run: |
export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
source $JOBSCRIPTSDIR/slurm_utils_common.sh
sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest_jaxlib_release.sub | tee output.log
sleep 2m
export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
job_wait "${SLURM_JOBID}" & PID=$!
touch "${SLURM_OUTPUT}"
echo -e " ---------------------------------------------------\n" \
"----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
"---------------------------------------------------\n"
tail --pid="${PID}" -f "${SLURM_OUTPUT}"
export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi
- name: Publish Test Results
uses: EnricoMi/publish-unit-test-result-action@46ab8d49369d898e381a607119161771bc65c2a6 # ratchet:EnricoMi/publish-unit-test-result-action@v2
if: always()
with:
junit_files: "outputs/*.xml"
- name: Upload run results from all nodes
uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # ratchet:actions/upload-artifact@v3
if: always()
with:
name: output-from-nodes
path: "outputs/*.txt"
report:
name: report
needs: [jaxlib-nightly, jaxlib-release]
needs: tests
if: |
failure()
&& github.event_name == 'schedule'

View File

@ -0,0 +1,61 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script used in the nightly-ci-multiprocess-gpu workflow to process logs."""
import argparse
import os
import re
import json
import datetime
import pandas as pd
stats_pat = r".*collection=train .*(timing/seconds=[\d.]+), (timing/seqs=[\d.]+), (timing/seqs_per_second=[\d.]+), (timing/seqs_per_second_per_core=[\d.]+), (timing/steps_per_second=[\d.]+), (timing/target_tokens_per_second=[\d.]+), (timing/target_tokens_per_second_per_core=[\d.]+).*"
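# For reference, the pattern above is meant to match T5X training summary lines of roughly this
# shape (hypothetical example for illustration, not taken from a real run):
#   ... collection=train ... timing/seconds=94.4, timing/seqs=2048, timing/seqs_per_second=21.7,
#   timing/seqs_per_second_per_core=1.4, timing/steps_per_second=1.1,
#   timing/target_tokens_per_second=12345.6, timing/target_tokens_per_second_per_core=771.6 ...
# Each match yields one tuple of "timing/<metric>=<value>" strings, which is parsed below.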
def main(logfile: str, outmd: str, outjson: str, name: str):
  print(f"Extracting content of {logfile}")
  print(f"and writing to {outmd} and {outjson}")
  with open(logfile, 'r') as fp:
    lines = fp.read()
  stats = re.findall(stats_pat, lines)
  data_parsed = [
    # Extract `metric` and `value` from `timing/metric=value`
    {re.split('=|/', s)[1]: float(re.split('=|/', s)[2]) for s in stat}
    for stat in stats
  ]
  df = pd.DataFrame(data_parsed).reset_index(drop=True)
  df.to_markdown(outmd, index=False)
  data = {
    'name': name,
    'date': datetime.datetime.now(tz=None).isoformat(),
    'data': data_parsed,
    'github': {k: v for (k, v) in os.environ.items() if k.startswith('GITHUB')}
  }
  with open(outjson, "w") as ofile:
    ofile.write(json.dumps(data, indent=4))

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("--logfile", help="The path to the input logfile")
  parser.add_argument("--outmd", help="The path to the parsed output markdown file to be created.",
                      default="metrics_report.md")
  parser.add_argument("--outjson", help="The path to the parsed output json file to be created.",
                      default="metrics_report.json")
  parser.add_argument("--name", help="Name of the benchmark to be added to the JSON.")
  args = parser.parse_args()
  main(logfile=args.logfile, outmd=args.outmd, outjson=args.outjson, name=args.name)

View File

@ -1,81 +0,0 @@
#!/bin/bash
#SBATCH -A ci-jax-gpu
#SBATCH -p compute
#SBATCH -N 2 # number of nodes
#SBATCH -t 00:15:00 # wall time
#SBATCH -J "ci-jax-gpu" # job name
#SBATCH --exclusive # exclusive node access
#SBATCH --mem=0 # all mem avail
#SBATCH --mail-type=FAIL # only send email on failures
#SBATCH --overcommit # Needed for pytorch
set -x
# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="nvcr.io/nvidian/jax_t5x:cuda11.4-cudnn8.2-ubuntu20.04-manylinux2014-multipython"
CONTAINER_NAME="multinode_ci_test_container"
BASE_WORKSPACE_DIR=$GITHUB_WORKSPACE
WORKSPACE_DIR=/workspace
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:/$WORKSPACE_DIR"
# Since the docker container doesn't contain MLX drivers for IB, following flags
# are needed to make NCCL work with an ethernet setup
# Note:@sudhakarsingh27 This is very specific, need to abstract this out
EXPORTS="--export=ALL,NCCL_SOCKET_IFNAME=enp45s0f0,NCCL_SOCKET_NTHREADS=2,NCCL_NSOCKS_PERTHREAD=2"
#-------------------------------------------------------------------------------
# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
python3.8 -m pip install --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html \
&& python3.8 -m pip install git+https://github.com/google/jax \
&& python3.8 -m pip install pytest \
&& python3.8 -m pip install pytest-forked \
&& mkdir -p /workspace/outputs/
EOF
# Main pytest command that runs the tests
read -r -d '' cmd <<EOF
date \
&& python3.8 -m pip list | grep jax \
&& python3.8 -m pytest -m SlurmMultiNodeGpuTest --forked -v -s --continue-on-collection-errors \
--junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
/workspace/tests/multiprocess_gpu_test.py
EOF
# create run specific output directory for ease of analysis
OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/"
mkdir -p $OUTPUT_DIR
# redirect both stdout and stderr in the same file for ease of analysis
OUTFILE="${OUTPUT_DIR}/output-test-jaxlib-nightly-%j-%n.txt"
# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo $setup_cmd
srun -o $OUTFILE -e $OUTFILE \
--ntasks-per-node=1 \
--container-writable \
--container-image="$CONTAINER" \
--container-name=$CONTAINER_NAME \
$MOUNTS \
$EXPORTS \
bash -c "${setup_cmd}"
# Barrier command
wait
# Run the actual pytest command
echo $cmd
srun -o $OUTFILE -e $OUTFILE \
--ntasks-per-node=8 \
--open-mode=append \
--container-writable \
--container-image="$CONTAINER" \
--container-name=$CONTAINER_NAME \
$MOUNTS \
$EXPORTS \
bash -c "${cmd}"
set +x

View File

@ -1,80 +0,0 @@
#!/bin/bash
#SBATCH -A ci-jax-gpu
#SBATCH -p compute
#SBATCH -N 2 # number of nodes
#SBATCH -t 00:15:00 # wall time
#SBATCH -J "ci-jax-gpu" # job name
#SBATCH --exclusive # exclusive node access
#SBATCH --mem=0 # all mem avail
#SBATCH --mail-type=FAIL # only send email on failures
#SBATCH --overcommit # Needed for pytorch
set -x
# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="nvcr.io/nvidian/jax_t5x:cuda11.4-cudnn8.2-ubuntu20.04-manylinux2014-multipython"
CONTAINER_NAME="multinode_ci_test_container"
BASE_WORKSPACE_DIR=$GITHUB_WORKSPACE
WORKSPACE_DIR=/workspace
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:/$WORKSPACE_DIR"
# Since the docker container doesn't contain MLX drivers for IB, following flags
# are needed to make NCCL work with an ethernet setup
# Note:@sudhakarsingh27 This is very specific, need to abstract this out
EXPORTS="--export=ALL,NCCL_SOCKET_IFNAME=enp45s0f0,NCCL_SOCKET_NTHREADS=2,NCCL_NSOCKS_PERTHREAD=2"
#-------------------------------------------------------------------------------
# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
python3.8 -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& python3.8 -m pip install pytest \
&& python3.8 -m pip install pytest-forked \
&& mkdir -p /workspace/outputs/
EOF
# Main pytest command that runs the tests
read -r -d '' cmd <<EOF
date \
&& python3.8 -m pip list | grep jax \
&& python3.8 -m pytest -m SlurmMultiNodeGpuTest --forked -v -s --continue-on-collection-errors \
--junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
/workspace/tests/multiprocess_gpu_test.py
EOF
# create run specific output directory for ease of analysis
OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/"
mkdir -p $OUTPUT_DIR
# redirect both stdout and stderr in the same file for ease of analysis
OUTFILE="${OUTPUT_DIR}/output-test-jaxlib-release-%j-%n.txt"
# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo $setup_cmd
srun -o $OUTFILE -e $OUTFILE \
--ntasks-per-node=1 \
--container-writable \
--container-image="$CONTAINER" \
--container-name=$CONTAINER_NAME \
$MOUNTS \
$EXPORTS \
bash -c "${setup_cmd}"
# Barrier command
wait
# Run the actual pytest command
echo $cmd
srun -o $OUTFILE -e $OUTFILE \
--ntasks-per-node=8 \
--open-mode=append \
--container-writable \
--container-image="$CONTAINER" \
--container-name=$CONTAINER_NAME \
$MOUNTS \
$EXPORTS \
bash -c "${cmd}"
set +x

View File

@ -0,0 +1,165 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A tool to create and destroy clusters on demand.
To create a cluster:
export WEB_API_URL=... # URL of the API endpoint
export WEB_API_TOKEN=... # Authentication token
python3 oci_cluster_manager.py create_cluster --pubkey "$(cat ~.ssh/id_rsa.pub)"
This will first try to find an existing running cluster, and otherwise attempt to create one in any region.
When succesfull, the output will contains the headnode hostkeys, the username and ip address of the cluster, or FAILED.
To create all previously created clusters
export WEB_API_URL=... # URL of the API endpoint
export WEB_API_TOKEN=... # Authentication token
python3 oci_cluster_manager.py destroy_clusters
This function is used to create and destroy clusters on demand.
A few caveats should be noted:
- Depending on resource availability, it might not be possible to create a cluster.
In that case, the script will eventually fail.
- Creating a cluster takes time (30 to 60 mins).
Similarly, destroying a cluster also takes time.
User should not attempt to concurrently create clusters.
As a rule of thumb, this script should only be used at most once every ~12 hours.
- In case a pull-request indirectly calls this script, users should take care to ensure
no other pipeline is attempting to create a cluster at the same time and within a ~3h time window.
- Clusters are automatically destroyed 2h after being created, regardless of whether
`destroy_all` is called or not.
"""
import os
import argparse
import time
import logging
import sys
import requests
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
_API_URL = os.environ['WEB_API_URL']
_API_AUTH = requests.auth.HTTPBasicAuth('login', os.environ['WEB_API_TOKEN'])
_SLEEP_TIME_SECONDS = int(os.environ.get("SLEEP_TIME_SECONDS", default="30"))
_RETRY_PER_REGION = int(os.environ.get("RETRY_PER_REGION", default="3"))
_REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS", default=120))
def get_regions():
  return requests.post(_API_URL, json={'name': 'list_regions'},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['regions']

def find_existing_cluster():
  return requests.post(_API_URL, json={'name': 'find_cluster'},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['region']

def get_cluster_ip(region):
  return requests.post(_API_URL, json={'name': 'get_cluster_ip', 'region': region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['cluster_ip']

def get_cluster_username(region):
  return requests.post(_API_URL, json={'name': 'get_cluster_username', 'region': region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['cluster_username']

def create_cluster(region):
  logging.debug(requests.post(_API_URL, json={'name': 'create_cluster', 'region': region},
                              auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS))

def add_pubkey(region, pubkey):
  logging.debug(requests.post(_API_URL, json={'name': 'add_pubkey', 'region': region, 'pubkey': pubkey},
                              auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS))

def get_cluster_hostkeys(region):
  return requests.post(_API_URL, json={'name': 'get_cluster_hostkeys', 'region': region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['cluster_hostkeys']

def get_status(region):
  return requests.post(_API_URL, json={'name': 'get_status', 'region': region},
                       auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS).json()['status']

def destroy_all_clusters():
  logging.debug(requests.post(_API_URL, json={'name': 'destroy_all_clusters'},
                              auth=_API_AUTH, timeout=_REQUEST_TIMEOUT_SECONDS))
def main():
  command_choices = ["create_cluster", "destroy_clusters"]
  parser = argparse.ArgumentParser(description="Creates and destroys compute clusters on-demand.")
  parser.add_argument("command", choices=command_choices,
                      help="""create_cluster will first try to find an existing running cluster,
                      and otherwise attempt to create one in any region.
                      When successful, the output will contain the headnode hostkeys, the username and the IP address of the cluster, or FAILED.
                      destroy_clusters will destroy all existing clusters.""")
  parser.add_argument("--pubkey", default=None,
                      help='Public key to upload to the cluster (relevant only with `create_cluster`).')
  args = parser.parse_args()
  if args.command == "create_cluster":
    found_region = find_existing_cluster()
    logging.info(f"Found cluster {found_region}")
    if found_region is None:
      logging.info("Could not find existing cluster. Trying to create one.")
      regions = get_regions()
      logging.info(f"Regions considered: {regions}")
      for region in regions * _RETRY_PER_REGION:
        logging.info(f"Trying region {region}")
        create_cluster(region)
        status = get_status(region)
        while status == 'WAIT':
          logging.info(f"Waiting {_SLEEP_TIME_SECONDS} seconds...")
          time.sleep(_SLEEP_TIME_SECONDS)
          status = get_status(region)
        if status == 'SUCCEEDED':
          logging.info("Successfully allocated cluster")
          found_region = region
          break
        else:
          logging.info("Moving to next region")
          continue
    else:
      logging.info(f"Found existing cluster in {found_region}")
    if found_region is not None:
      logging.info(f"Found cluster in {found_region}")
      logging.info(f"Adding pubkey {args.pubkey} to cluster")
      add_pubkey(found_region, args.pubkey)
      logging.info("Fetching host keys, username and IP address")
      ip = get_cluster_ip(found_region)
      username = get_cluster_username(found_region)
      hostkeys = get_cluster_hostkeys(found_region)
      print(hostkeys)
      print(username)
      print(ip)
    else:
      logging.info("Failed to allocate cluster")
      sys.exit(1)
  elif args.command == "destroy_clusters":
    logging.info("Destroying all")
    destroy_all_clusters()
  else:
    raise ValueError(f"Wrong `command` argument. Got {args.command}. Valid choices are {command_choices}")

if __name__ == "__main__":
  main()

View File

@ -0,0 +1,3 @@
requests
pandas
tabulate

View File

@ -0,0 +1,96 @@
#!/bin/bash
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -x
echo "GITHUB_WORKSPACE_REMOTE=${GITHUB_WORKSPACE_REMOTE}"
TEST_NAME="e2e-t5x"
# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="docker://nvcr.io#nvidia/tensorflow:22.11-tf2-py3"
CONTAINER_NAME="multinode_ci_test_container_${TEST_NAME}_$RANDOM"
# create run specific output directory for ease of analysis
BASE_OUTPUT_DIR="${GITHUB_WORKSPACE_REMOTE}/outputs/"
OUTFILE="${BASE_OUTPUT_DIR}/output-test-${TEST_NAME}-%J-%n-%t.txt"
mkdir -p "$BASE_OUTPUT_DIR"
# Default env variables for paths required by t5x training scripts
OUTPUT_DIR=/outputs
E2E_TESTS_WORKSPACE_DIR=/localdir/e2e_tests_workspace
T5X_DIR="${E2E_TESTS_WORKSPACE_DIR}/t5x"
TFDS_DATA_DIR="${E2E_TESTS_WORKSPACE_DIR}/datasets"
MOUNTS="--container-mounts=$BASE_OUTPUT_DIR:$OUTPUT_DIR"
EXPORTS="--export=ALL,PYTHONPATH=${T5X_DIR}"
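# Note: everything the container writes under ${OUTPUT_DIR} (/outputs) lands in ${BASE_OUTPUT_DIR}
# on the host via the mount above; the workflow later copies that directory back with scp in its
# "Gather results" step.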
#-------------------------------------------------------------------------------
# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
rm -rf ${E2E_TESTS_WORKSPACE_DIR}/* \
&& mkdir -p ${E2E_TESTS_WORKSPACE_DIR} \
&& mkdir -p ${TFDS_DATA_DIR} \
&& python3.8 -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& git clone https://github.com/google-research/t5x.git ${T5X_DIR} \
&& python3.8 -m pip install ${T5X_DIR} \
&& python3.8 -m pip install ${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config \
&& hostname > ${E2E_TESTS_WORKSPACE_DIR}/hostname.txt
EOF
# Main pytest command that runs the tests
read -r -d '' cmd <<EOF
date \
&& python3.8 -m pip list | grep jax \
&& python3.8 ${T5X_DIR}/t5x/train.py \
--gin_file="${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/small_pretrain_dummy_wikipedia.gin" \
--gin.MODEL_DIR=\"${OUTPUT_DIR}/model_dir\" \
--gin.network.T5Config.dtype=\"bfloat16\" \
--gin.TRAIN_STEPS=100 \
--gin.CheckpointConfig.save=None \
--multiprocess_gpu
EOF
# Count errors
errors=0 && trap errors=1 ERR
# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo "$setup_cmd"
srun -o "$OUTFILE" -e "$OUTFILE" \
-t 00:10:00 \
--ntasks-per-node=1 \
--container-writable \
--container-image="$CONTAINER" \
--container-name="$CONTAINER_NAME" \
"$MOUNTS" \
"$EXPORTS" \
timeout -v -k 0.5m 9m bash -c "${setup_cmd}"
# Run the actual pytest command
echo "$cmd"
srun -o "$OUTFILE" -e "$OUTFILE" \
-t 00:15:00 \
--ntasks-per-node=8 \
--open-mode=append \
--container-writable \
--container-image="$CONTAINER" \
--container-name="$CONTAINER_NAME" \
"$MOUNTS" \
"$EXPORTS" \
timeout -v -k 0.5m 14m bash -c "${cmd}"
test $errors = 0

View File

@ -1,100 +0,0 @@
#! /bin/bash
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# These utility functions are used to monitor SLURM multi-node jobs
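# Typical usage from a CI shell (sketch mirroring the workflow that sources this file;
# some_job.sub is a placeholder batch script):
#   source slurm_utils_common.sh
#   sbatch -N 2 some_job.sub | tee output.log
#   export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
#   job_wait "${SLURM_JOBID}" & PID=$!
#   tail --pid="${PID}" -f "${SLURM_OUTPUT}"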
job_exit_code() {
  shopt -s lastpipe
  if [ "$#" -ne 1 ]; then
    exit 1
  fi
  JOBID="$1"
  sacct -j "${JOBID}" -n --format=exitcode | sort -r -u | head -1 | cut -f 1 -d":" | sed 's/ //g'
  exit ${PIPESTATUS[0]}
}

job_state(){
  shopt -s lastpipe
  if [ "$#" -ne 1 ]; then
    exit 1
  fi
  JOBID="$1"
  sacct -j "${JOBID}" --format State --parsable2 --noheader |& head -n 1
  exit ${PIPESTATUS[0]}
}

job_nodes(){
  set -euo pipefail
  shopt -s lastpipe
  if [ "$#" -ne 1 ]; then
    exit 1
  fi
  JOBID="$1"
  sacct -j "${JOBID}" -X -n --format=nodelist%400 | sed 's/ //g'
}

job_time(){
  set -euo pipefail
  shopt -s lastpipe
  if [ "$#" -ne 1 ]; then
    exit 1
  fi
  JOBID="$1"
  ## Note: using export so that this line doesn't cause the script to immediately exit if the subshell failed when running under set -e
  export WALLTIME=$(sacct -j "${JOBID}" --format ElapsedRaw --parsable2 --noheader | head -n 1)
  echo ${WALLTIME:-unknown}
}

job_wait(){
  set -euo pipefail
  shopt -s lastpipe
  if [ "$#" -ne 1 ]; then
    exit 1
  fi
  echo "checking for jobid $1"
  JOBID="$1"
  while true; do
    export STATE=$(job_state "${JOBID}")
    case "${STATE}" in
      PENDING|RUNNING|REQUEUED)
        sleep 15s
        ;;
      *)
        sleep 30s
        echo "Exiting with SLURM job status '${STATE}'"
        exit 0
        ;;
    esac
  done
}