Merge pull request #11721 from sudhakarsingh27:main

PiperOrigin-RevId: 465381834
This commit is contained in:
jax authors 2022-08-04 12:52:16 -07:00
commit 0a8ca1982c
5 changed files with 293 additions and 0 deletions

View File

@ -0,0 +1,54 @@
name: Nightly JAX CI on NVIDIA GPUs
# This configures running JAX tests (especially multi-node multi-gpu) against nightly GPU jaxlib builds.
# This is expected to fail frequently, and so we don't run it against every commit and PR in the repository.
# Portions of this adapted from https://github.com/google/jax/blob/main/.github/workflows/upstream-nightly.yaml

# Controls when the workflow will run
on:
  schedule:
    - cron: "0 12 * * *" # Daily at 12:00 UTC
  workflow_dispatch: # allows triggering the workflow run manually
  pull_request: # Automatically trigger on pull requests affecting this file
    branches:
      - main
    paths:
      - '**workflows/nightly-ci-multiprocess-gpu.yml'

jobs:
  build:
    # Self-hosted runner: must have sbatch/scontrol and the helper functions
    # sourced below available (i.e. live on a SLURM login node).
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v3
      - name: Launch slurm job and hook output to this shell
        run: |
          export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
          # Provides job_wait / job_state / job_time / job_exit_code.
          source $JOBSCRIPTSDIR/slurm_utils_common.sh
          # Submit the 2-node batch job; sbatch prints "Submitted batch job <id>".
          sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest.sub | tee output.log
          sleep 2m
          # 4th field of the sbatch confirmation line is the job id.
          export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
          export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
          # Poll job state in the background; tail below follows the job's
          # stdout file until that poller (job_wait) exits.
          job_wait "${SLURM_JOBID}" & PID=$!
          touch "${SLURM_OUTPUT}"
          echo -e " ---------------------------------------------------\n" \
               "----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
               "---------------------------------------------------\n"
          tail --pid="${PID}" -f "${SLURM_OUTPUT}"
          export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
          export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
          export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
          # Fail the workflow if the SLURM job failed or ended in any
          # state other than COMPLETED (cancelled, timeout, ...).
          if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
          if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi
      - name: Publish Test Results
        uses: EnricoMi/publish-unit-test-result-action@v1
        if: always()
        with:
          files: "outputs/*.xml"
      - name: Upload run results from all nodes
        uses: actions/upload-artifact@v3
        if: always()
        with:
          name: output-from-nodes
          path: "outputs/*.txt"

View File

@ -0,0 +1,68 @@
#!/bin/bash
# SLURM batch script: runs the JAX multi-node pytest suite inside an NVIDIA
# container on every allocated node (one task per node) via the srun
# --container-* flags.  Per-task output is collected under $GITHUB_WORKSPACE/outputs.

#SBATCH -A arbitrary
#SBATCH -p compute
#SBATCH -N 2                   # number of nodes
#SBATCH -t 01:00:00            # wall time
#SBATCH -J "arbitrary"         # job name (<< CHANGE ! >>)
#SBATCH --exclusive            # exclusive node access
#SBATCH --mem=0                # all mem avail
#SBATCH --mail-type=FAIL       # only send email on failure
#SBATCH --ntasks-per-node=1    # n tasks per machine
#SBATCH --overcommit           # Needed for pytorch

set -x

# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="nvcr.io/nvidian/jax_t5x:jax_0.3.14"
CONTAINER_NAME="multinode_ci_test_container"

BASE_WORKSPACE_DIR=$GITHUB_WORKSPACE
WORKSPACE_DIR=/workspace

# Mount the runner's workspace at /workspace inside the container.
# (Bug fix: WORKSPACE_DIR already begins with '/', so the original
# "...:/$WORKSPACE_DIR" produced the doubled mount target '//workspace'.)
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR"
#-------------------------------------------------------------------------------

# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
pip install pytest \
&& mkdir -p /workspace/outputs/
EOF

# Main pytest command that runs the tests.  Each task writes its own JUnit
# XML keyed by SLURM_PROCID (escaped so it expands inside the container,
# not at submission time).
read -r -d '' cmd <<EOF
date \
&& pytest -v -s --continue-on-collection-errors \
--junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
/workspace/tests/distributed_multinode_test.py
EOF

# create run specific output directory for ease of analysis
OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/"
mkdir -p "$OUTPUT_DIR"

# redirect both stdout and stderr in the same file for ease of analysis
# (%j = job id, %n = node id; expanded by srun)
OUTFILE="${OUTPUT_DIR}/output-%j-%n.txt"

# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo "$setup_cmd"
srun -o "$OUTFILE" -e "$OUTFILE" \
  --container-writable \
  --container-image="$CONTAINER" \
  --container-name=$CONTAINER_NAME \
  $MOUNTS \
  bash -c "${setup_cmd}"

# Barrier: srun runs in the foreground, so this is effectively a no-op; kept
# to make explicit that setup must finish on all nodes before the tests run.
wait

# Run the actual pytest command
echo "$cmd"
srun -o "$OUTFILE" -e "$OUTFILE" \
  --container-writable \
  --container-image="$CONTAINER" \
  --container-name=$CONTAINER_NAME \
  $MOUNTS \
  bash -c "${cmd}"
set +x

View File

@ -0,0 +1,100 @@
#! /bin/bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# These utility functions are used to monitor SLURM multi-node jobs
# job_exit_code JOBID
# Print the highest exit code sacct recorded for any step of JOBID (the
# "rc" part of sacct's "rc:signal" ExitCode, whitespace stripped), then
# exit with sacct's own status.  Because callers invoke this in a command
# substitution, the `exit` only leaves that subshell and becomes the
# function's return status.
job_exit_code() {
    # lastpipe keeps the last pipeline stage in this shell so PIPESTATUS
    # below still refers to this pipeline.
    shopt -s lastpipe
    if [ "$#" -ne 1 ]; then
        exit 1
    fi
    JOBID="$1"
    # sort -r -u then head -1 selects the lexicographically largest
    # (i.e. worst) exit code across all job steps.
    sacct -j "${JOBID}" -n --format=exitcode | sort -r -u | head -1 | cut -f 1 -d":" | sed 's/ //g'
    # Propagate sacct's status, not the pipeline tail's.
    exit ${PIPESTATUS[0]}
}
# job_state JOBID
# Print the State of the first sacct record for JOBID (e.g. RUNNING,
# COMPLETED), then exit with sacct's status.  `|&` also pipes sacct's
# stderr into head so error text cannot leak past the substitution.
job_state(){
    # Needed so PIPESTATUS below reflects this pipeline (see job_exit_code).
    shopt -s lastpipe
    if [ "$#" -ne 1 ]; then
        exit 1
    fi
    JOBID="$1"
    sacct -j "${JOBID}" --format State --parsable2 --noheader |& head -n 1
    exit ${PIPESTATUS[0]}
}
# job_nodes JOBID
# Print the (compressed) node list allocated to JOBID, spaces stripped.
# -X limits output to the allocation itself; %400 widens the column so
# long node lists are not truncated by sacct's default width.
job_nodes(){
    # Strict mode: any sacct/pipe failure aborts the enclosing (sub)shell.
    set -euo pipefail
    shopt -s lastpipe
    if [ "$#" -ne 1 ]; then
        exit 1
    fi
    JOBID="$1"
    sacct -j "${JOBID}" -X -n --format=nodelist%400 | sed 's/ //g'
}
# job_time JOBID
# Print JOBID's elapsed wall time in raw seconds (sacct ElapsedRaw), or
# "unknown" if sacct returned nothing.
job_time(){
    set -euo pipefail
    shopt -s lastpipe
    if [ "$#" -ne 1 ]; then
        exit 1
    fi
    JOBID="$1"
    ## Note: using export so that this line doesn't cause the script to immediately exit if the subshell failed when running under set -e
    export WALLTIME=$(sacct -j "${JOBID}" --format ElapsedRaw --parsable2 --noheader | head -n 1)
    echo ${WALLTIME:-unknown}
}
# job_wait JOBID
# Block until JOBID leaves the PENDING/RUNNING/REQUEUED states, polling
# sacct every 15 s via job_state.  Always exits 0 on a terminal state;
# the caller is expected to query job_state/job_exit_code afterwards to
# decide success or failure.
job_wait(){
    set -euo pipefail
    shopt -s lastpipe
    if [ "$#" -ne 1 ]; then
        exit 1
    fi
    echo "checking for jobid $1"
    JOBID="$1"
    while true; do
        export STATE=$(job_state "${JOBID}")
        case "${STATE}" in
            PENDING|RUNNING|REQUEUED)
                sleep 15s
                ;;
            *)
                # NOTE(review): extra 30 s pause before returning —
                # presumably to let the job's output files finish
                # flushing for the tail that follows this poller; confirm.
                sleep 30s
                echo "Exiting with SLURM job status '${STATE}'"
                exit 0
                ;;
        esac
    done
}

View File

@ -93,6 +93,15 @@ py_test(
],
)
# Multi-node distributed GPU test.  The test module skips itself unless it
# runs inside a 2-node SLURM allocation (SLURM_JOB_NUM_NODES == "2"), so
# running this target outside that environment is a no-op.
py_test(
    name = "distributed_multinode_test",
    srcs = ["distributed_multinode_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
py_test(
name = "distributed_test",
srcs = ["distributed_test.py"],

View File

@ -0,0 +1,62 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import jax
import jax._src.lib
from jax._src import test_util as jtu
import jax.numpy as jnp
@unittest.skipIf(
    os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
    "Slurm environment with at least two nodes needed!")
class MultiNodeGpuTest(jtu.JaxTestCase):
  """Multi-node, multi-GPU smoke test driven by a SLURM environment.

  Derives the distributed-coordinator address and the process topology from
  SLURM environment variables, initializes ``jax.distributed``, and checks
  that a pmapped ``psum`` sees every device across all processes.
  """

  @staticmethod
  def _first_host(nodelist):
    """Return the hostname of the first node in a SLURM nodelist.

    Handles both the plain comma-separated form (``"node0,node1"``) and
    the compressed bracket form (``"node[0-3,7]"`` -> ``"node0"``).  The
    original inline parsing assumed a ``[`` was always present (IndexError
    otherwise) and never stripped the closing bracket or range suffix
    (``"node[1-2]"`` -> ``"node1-2]"``, an invalid hostname).
    """
    if '[' in nodelist:
      prefix, ranges = nodelist.split('[', 1)
      # First range element, without any "-end" suffix or trailing ']'.
      return prefix + ranges.split(',')[0].split('-')[0].rstrip(']')
    return nodelist.split(',')[0]

  def test_gpu_multi_node_initialize_and_psum(self):
    # Hookup the ENV vars expected to be set already in the SLURM environment
    nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
    # Pre-assign so a missing nodelist fails the assertion below instead of
    # raising NameError (the original left the name unbound in that case).
    coordinator_address = None
    if nodelist is not None:
      coordinator_address = self._first_host(nodelist)
    num_tasks = os.environ.get("SLURM_NPROCS", None)
    taskid = os.environ.get("SLURM_PROCID", None)
    localid = os.environ.get("SLURM_LOCALID", None)

    # fixing port since it needs to be the same for all the processes
    port = "54321"

    print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
          f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")

    # Fail fast with a readable message if SLURM did not provide what we
    # need (clearer than the original assertEqual(... is None or ..., False)).
    self.assertIsNotNone(coordinator_address)
    self.assertIsNotNone(num_tasks)
    self.assertIsNotNone(taskid)

    jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
                               num_processes=int(num_tasks),
                               process_id=int(taskid))

    print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
          f"Devices per task: {jax.local_device_count()}")

    # Every process must see the union of all tasks' local devices.
    self.assertEqual(jax.device_count(),
                     int(num_tasks) * jax.local_device_count())

    # psum over axis "i": every element of y equals the global device count.
    x = jnp.ones(jax.local_device_count())
    y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
    self.assertEqual(y[0], jax.device_count())
    print(y)