Squashed commit of the following:

commit 0b4c3f05a49037be93eb0612113e193f3a8d61c5
Author: Sudhakar Singh <sudhakars@nvidia.com>
Date:   Thu Aug 4 09:53:04 2022 -0700

    change the path

commit 2c629739c1cfa45d848a2cf7109d329c1262e6ac
Author: Sudhakar Singh <sudhakars@nvidia.com>
Date:   Wed Aug 3 16:37:46 2022 -0700

    rename file to reflect current objective

commit ef46bcae6cd66d6fe7b04bd6d8aeed42c4f3ddfa
Author: Sudhakar Singh <sudhakars@nvidia.com>
Date:   Wed Aug 3 15:56:32 2022 -0700

    correct formatting

commit e5da60ad855592d5f150612f65ad679872160132
Author: Sudhakar Singh <sudhakars@nvidia.com>
Date:   Wed Aug 3 15:26:32 2022 -0700

    Add multi-node multi-GPU JAX tests

    This adds multi-node multi-GPU test for `jax.distributed.initialize`.
    Presently, this is expected to run on a nightly basis. Under the hood,
    SLURM is used to launch the `pytest <test_name>` commands on multiple
    nodes.

    Resolves: #11648
This commit is contained in:
Sudhakar Singh 2022-08-04 10:13:50 -07:00
parent a3ad01a9be
commit 1565fd2525
4 changed files with 284 additions and 0 deletions

View File

@ -0,0 +1,54 @@
name: Nightly JAX CI on NVIDIA GPUs
# This configures running JAX tests (especially multi-node multi-gpu) against nightly GPU jaxlib builds.
# This is expected to fail frequently, and so we don't run it against every commit and PR in the repository.
# Portions of this adapted from https://github.com/google/jax/blob/main/.github/workflows/upstream-nightly.yaml

# Controls when the workflow will run
on:
  schedule:
    - cron: "0 12 * * *" # Daily at 12:00 UTC
  workflow_dispatch: # allows triggering the workflow run manually
  pull_request: # Automatically trigger on pull requests affecting this file
    branches:
      - main
    paths:
      - '**workflows/nightly-ci-multiprocess-gpu.yml'

jobs:
  build:
    # Runs on an NVIDIA-maintained runner with access to a SLURM cluster.
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v3

      - name: Launch slurm job and hook output to this shell
        run: |
          export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
          source $JOBSCRIPTSDIR/slurm_utils_common.sh
          sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest.sub | tee output.log
          sleep 2m
          export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
          export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
          job_wait "${SLURM_JOBID}" & PID=$!
          touch "${SLURM_OUTPUT}"
          echo -e " ---------------------------------------------------\n" \
                  "----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
                  "---------------------------------------------------\n"
          tail --pid="${PID}" -f "${SLURM_OUTPUT}"
          export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
          export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
          export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
          if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
          if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi

      - name: Publish Test Results
        uses: EnricoMi/publish-unit-test-result-action@v1
        if: always()
        with:
          files: "outputs/*.xml"

      - name: Upload run results from all nodes
        uses: actions/upload-artifact@v3
        if: always()
        with:
          name: output-from-nodes
          path: "outputs/*.txt"

View File

@ -0,0 +1,68 @@
#!/bin/bash
# SLURM batch script: launches the multi-node multi-GPU JAX tests inside a
# container on every allocated node.  Submitted by the nightly GitHub Actions
# workflow as `sbatch -N 2 multinode_pytest.sub`.
#SBATCH -A arbitrary
#SBATCH -p compute
#SBATCH -N 2                   # number of nodes
#SBATCH -t 01:00:00            # wall time
#SBATCH -J "arbitrary"         # job name (<< CHANGE ! >>)
#SBATCH --exclusive            # exclusive node access
#SBATCH --mem=0                # all mem avail
#SBATCH --mail-type=FAIL       # only send email on failure
#SBATCH --ntasks-per-node=1    # n tasks per machine
#SBATCH --overcommit           # Needed for pytorch

set -x

# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="nvcr.io/nvidian/jax_t5x:jax_0.3.14"
CONTAINER_NAME="multinode_ci_test_container"

BASE_WORKSPACE_DIR=$GITHUB_WORKSPACE
WORKSPACE_DIR=/workspace

# FIX: WORKSPACE_DIR is already an absolute path; the previous
# ":/$WORKSPACE_DIR" produced a doubled slash ("//workspace") in the
# container mount target.
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR"
#-------------------------------------------------------------------------------

# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
pip install pytest \
&& mkdir -p /workspace/outputs/
EOF

# Main pytest command that runs the tests.  \${SLURM_PROCID} is escaped so it
# expands inside the container, giving each task its own junit file.
read -r -d '' cmd <<EOF
date \
&& pytest -v -s --continue-on-collection-errors \
--junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
/workspace/tests/distributed_multinode_test.py
EOF

# create run specific output directory for ease of analysis
OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/"
mkdir -p $OUTPUT_DIR

# redirect both stdout and stderr in the same file for ease of analysis
# (%j = job id, %n = node id)
OUTFILE="${OUTPUT_DIR}/output-%j-%n.txt"

# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo $setup_cmd
srun -o $OUTFILE -e $OUTFILE \
--container-writable \
--container-image="$CONTAINER" \
--container-name=$CONTAINER_NAME \
$MOUNTS \
bash -c "${setup_cmd}"

# Barrier: srun itself blocks until every task completes, so this `wait` is
# only a safety net (there are no background jobs at this point).
wait

# Run the actual pytest command
echo $cmd
srun -o $OUTFILE -e $OUTFILE \
--container-writable \
--container-image="$CONTAINER" \
--container-name=$CONTAINER_NAME \
$MOUNTS \
bash -c "${cmd}"

set +x

View File

@ -0,0 +1,100 @@
#! /bin/bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# These utility functions are used to monitor SLURM multi-node jobs
# Print the numeric exit code sacct recorded for the given SLURM job, and
# exit with sacct's own status so a failed query is detectable by the caller.
# Usage (in a command substitution): code=$(job_exit_code JOBID)
# NOTE: uses `exit`, not `return` -- only call this in a subshell/$(...),
# or it will terminate the calling shell.
job_exit_code() {
shopt -s lastpipe
# Exactly one argument (the SLURM job id) is required.
if [ "$#" -ne 1 ]; then
exit 1
fi
JOBID="$1"
# `sort -r -u | head -1` keeps the highest ExitCode across all job steps;
# `cut` drops the ":signal" suffix; `sed` strips sacct's padding spaces.
sacct -j "${JOBID}" -n --format=exitcode | sort -r -u | head -1 | cut -f 1 -d":" | sed 's/ //g'
# Propagate the status of sacct (first command of the pipeline), not the
# filters that follow it.
exit ${PIPESTATUS[0]}
}
# Print the current state of a SLURM job (e.g. PENDING, RUNNING, COMPLETED).
# Usage (in a command substitution): state=$(job_state JOBID)
# NOTE: uses `exit`, not `return` -- only call via $(...).
job_state(){
shopt -s lastpipe
# Exactly one argument (the SLURM job id) is required.
if [ "$#" -ne 1 ]; then
exit 1
fi
JOBID="$1"
# `|&` also folds sacct's stderr into the pipe; head keeps just the first
# reported state line (the job allocation itself, not individual steps).
sacct -j "${JOBID}" --format State --parsable2 --noheader |& head -n 1
# Propagate sacct's exit status so callers can detect a failed query.
exit ${PIPESTATUS[0]}
}
# Print the node list assigned to a SLURM job, with spaces removed.
# Usage (in a command substitution): nodes=$(job_nodes JOBID)
# -X restricts output to the allocation itself (no per-step rows);
# %400 widens the nodelist field so long lists are not truncated.
job_nodes(){
set -euo pipefail
shopt -s lastpipe
# Exactly one argument (the SLURM job id) is required.
if [ "$#" -ne 1 ]; then
exit 1
fi
JOBID="$1"
sacct -j "${JOBID}" -X -n --format=nodelist%400 | sed 's/ //g'
}
# Print the job's elapsed wall time in raw seconds, or "unknown" when sacct
# returned nothing.
# Usage (in a command substitution): secs=$(job_time JOBID)
job_time(){
set -euo pipefail
shopt -s lastpipe
# Exactly one argument (the SLURM job id) is required.
if [ "$#" -ne 1 ]; then
exit 1
fi
JOBID="$1"
## Note: using export so that this line doesn't cause the script to immediately exit if the subshell failed when running under set -e
export WALLTIME=$(sacct -j "${JOBID}" --format ElapsedRaw --parsable2 --noheader | head -n 1)
echo ${WALLTIME:-unknown}
}
# Block until the given SLURM job leaves the PENDING/RUNNING/REQUEUED states,
# then print the terminal state and exit 0.  Intended to be run in the
# background (`job_wait ID & PID=$!`) so the caller can tail logs meanwhile
# and use `tail --pid=$PID` to stop when the job finishes.
job_wait(){
set -euo pipefail
shopt -s lastpipe
# Exactly one argument (the SLURM job id) is required.
if [ "$#" -ne 1 ]; then
exit 1
fi
echo "checking for jobid $1"
JOBID="$1"
# Poll the job state every 15s until it reaches a terminal state.
while true; do
export STATE=$(job_state "${JOBID}")
case "${STATE}" in
PENDING|RUNNING|REQUEUED)
sleep 15s
;;
*)
# Terminal state reached.  The extra sleep presumably gives SLURM time to
# flush the job's output file before the caller stops tailing it -- confirm.
sleep 30s
echo "Exiting with SLURM job status '${STATE}'"
exit 0
;;
esac
done
}

View File

@ -0,0 +1,62 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import jax
import jax._src.lib
from jax._src import test_util as jtu
import jax.numpy as jnp
@unittest.skipIf(
    os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
    "Slurm environment with at least two nodes needed!")
class MultiNodeGpuTest(jtu.JaxTestCase):
    """Multi-node multi-GPU tests, driven by SLURM-provided env variables.

    Each SLURM task runs this test; the tasks rendezvous through
    ``jax.distributed.initialize`` using a coordinator address derived from
    the SLURM nodelist.
    """

    @staticmethod
    def _coordinator_address(nodelist):
        """Return the coordinator hostname derived from a SLURM nodelist.

        SLURM compresses hostnames, e.g. ``"node[1,3]"``; we take the common
        prefix plus the first entry inside the brackets (``"node1"``).
        A plain hostname without brackets is returned unchanged, and ``None``
        input yields ``None``.  (Full SLURM nodelist syntax -- nested ranges,
        multiple bracket groups -- is not handled.)
        """
        if nodelist is None:
            return None
        prefix, _, rest = nodelist.partition('[')
        if not rest:
            # No bracket group: the nodelist is already a single hostname.
            return prefix
        # First element of the bracket group; for a range like "001-002"
        # take its first node, and drop a trailing "]" (single-entry lists).
        first = rest.split(',')[0].split('-')[0].rstrip(']')
        return prefix + first

    def test_gpu_multi_node_initialize_and_psum(self):
        """Initialize jax.distributed across nodes and verify a global psum."""
        # Hook up the env vars expected to be set already by SLURM.
        # coordinator_address is now always bound, even when the nodelist env
        # var is missing (previously that path raised NameError below).
        nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
        coordinator_address = self._coordinator_address(nodelist)
        num_tasks = os.environ.get("SLURM_NPROCS", None)
        taskid = os.environ.get("SLURM_PROCID", None)
        localid = os.environ.get("SLURM_LOCALID", None)

        # Fixed port since it needs to be the same for all the processes.
        port = "54321"

        print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
              f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")

        # Fail early, naming the missing value, instead of a TypeError from
        # int(None) further down.
        self.assertIsNotNone(coordinator_address)
        self.assertIsNotNone(num_tasks)
        self.assertIsNotNone(taskid)

        jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
                                   num_processes=int(num_tasks),
                                   process_id=int(taskid))

        print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
              f"Devices per task: {jax.local_device_count()}")

        # Every task should see the global device count = tasks x local devices.
        self.assertEqual(jax.device_count(),
                         int(num_tasks) * jax.local_device_count())

        # A psum over all devices must equal the global device count.
        x = jnp.ones(jax.local_device_count())
        y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
        self.assertEqual(y[0], jax.device_count())
        print(y)