mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11721 from sudhakarsingh27:main
PiperOrigin-RevId: 465381834
This commit is contained in:
commit
0a8ca1982c
54
.github/workflows/nightly-ci-multiprocess-gpu.yml
vendored
Normal file
54
.github/workflows/nightly-ci-multiprocess-gpu.yml
vendored
Normal file
@ -0,0 +1,54 @@
|
||||
name: Nightly JAX CI on NVIDIA GPUs
# This configures running JAX tests (especially multi-node multi-gpu) against nightly GPU jaxlib builds.
# This is expected to fail frequently, and so we don't run it against every commit and PR in the repository.
# Portions of this adapted from https://github.com/google/jax/blob/main/.github/workflows/upstream-nightly.yaml

# Controls when the workflow will run
on:
  schedule:
    - cron: "0 12 * * *" # Daily at 12:00 UTC
  workflow_dispatch: # allows triggering the workflow run manually
  pull_request: # Automatically trigger on pull requests affecting this file
    branches:
      - main
    paths:
      - '**workflows/nightly-ci-multiprocess-gpu.yml'

jobs:
  build:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v3

      # Submits a 2-node Slurm batch job, then streams its output into this
      # shell until the job finishes so the workflow log mirrors the run.
      - name: Launch slurm job and hook output to this shell
        run: |
          export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
          source $JOBSCRIPTSDIR/slurm_utils_common.sh
          sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest.sub | tee output.log
          sleep 2m
          export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
          export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
          job_wait "${SLURM_JOBID}" & PID=$!
          touch "${SLURM_OUTPUT}"
          echo -e " ---------------------------------------------------\n" \
               "----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
               "---------------------------------------------------\n"
          tail --pid="${PID}" -f "${SLURM_OUTPUT}"
          export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
          export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
          export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
          if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
          if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi

      - name: Publish Test Results
        uses: EnricoMi/publish-unit-test-result-action@v1
        if: always()
        with:
          files: "outputs/*.xml"

      - name: Upload run results from all nodes
        uses: actions/upload-artifact@v3
        if: always()
        with:
          name: output-from-nodes
          path: "outputs/*.txt"
|
68
.github/workflows/slurm_job_scripts/multinode_pytest.sub
vendored
Normal file
68
.github/workflows/slurm_job_scripts/multinode_pytest.sub
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
#!/bin/bash
# Slurm batch script: runs the JAX multi-node pytest suite inside an NVIDIA
# container on 2 nodes (1 task per node), with per-node output files.
#SBATCH -A arbitrary
#SBATCH -p compute
#SBATCH -N 2                      # number of nodes
#SBATCH -t 01:00:00               # wall time
#SBATCH -J "arbitrary"            # job name (<< CHANGE ! >>)
#SBATCH --exclusive               # exclusive node access
#SBATCH --mem=0                   # all mem avail
#SBATCH --mail-type=FAIL          # only send email on failure
#SBATCH --ntasks-per-node=1       # n tasks per machine
#SBATCH --overcommit              # Needed for pytorch

set -x

# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="nvcr.io/nvidian/jax_t5x:jax_0.3.14"
CONTAINER_NAME="multinode_ci_test_container"

BASE_WORKSPACE_DIR=$GITHUB_WORKSPACE
WORKSPACE_DIR=/workspace

MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:/$WORKSPACE_DIR"
#-------------------------------------------------------------------------------

# Setup command to be run before the actual pytest command
read -r -d '' setup_cmd <<EOF
pip install pytest \
&& mkdir -p /workspace/outputs/
EOF

# Main pytest command that runs the tests
# NOTE: \${SLURM_PROCID} is escaped so it expands inside the container task,
# giving each rank its own junit XML file.
read -r -d '' cmd <<EOF
date \
&& pytest -v -s --continue-on-collection-errors \
--junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
/workspace/tests/distributed_multinode_test.py
EOF

# create run specific output directory for ease of analysis
OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/"
mkdir -p $OUTPUT_DIR

# redirect both stdout and stderr in the same file for ease of analysis
# (%j = jobid, %n = node id)
OUTFILE="${OUTPUT_DIR}/output-%j-%n.txt"

# Run any setup commands before the actual pytest command to make sure
# that the processes are launched together
echo $setup_cmd
srun -o $OUTFILE -e $OUTFILE \
  --container-writable \
  --container-image="$CONTAINER" \
  --container-name=$CONTAINER_NAME \
  $MOUNTS \
  bash -c "${setup_cmd}"

# Barrier command
wait

# Run the actual pytest command
echo $cmd
srun -o $OUTFILE -e $OUTFILE \
  --container-writable \
  --container-image="$CONTAINER" \
  --container-name=$CONTAINER_NAME \
  $MOUNTS \
  bash -c "${cmd}"
set +x
|
100
.github/workflows/slurm_job_scripts/slurm_utils_common.sh
vendored
Executable file
100
.github/workflows/slurm_job_scripts/slurm_utils_common.sh
vendored
Executable file
@ -0,0 +1,100 @@
|
||||
#! /bin/bash

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# These utility functions are used to monitor SLURM multi-node jobs
||||
# Print the worst (highest, after reverse sort) exit code sacct recorded for
# the given job id; the function's own exit status is sacct's exit status.
# Usage: job_exit_code JOBID
job_exit_code() {
  # lastpipe so PIPESTATUS of the sacct pipeline is visible after it runs
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  sacct -j "${JOBID}" -n --format=exitcode | sort -r -u | head -1 | cut -f 1 -d":" | sed 's/ //g'

  # propagate sacct's (first pipeline stage) exit status
  exit ${PIPESTATUS[0]}
}
|
||||
|
||||
# Print the first State line sacct reports for the given job id (e.g.
# PENDING, RUNNING, COMPLETED); exits with sacct's exit status.
# Usage: job_state JOBID
job_state(){
  # lastpipe so PIPESTATUS of the sacct pipeline is visible after it runs
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  sacct -j "${JOBID}" --format State --parsable2 --noheader |& head -n 1

  # propagate sacct's (first pipeline stage) exit status
  exit ${PIPESTATUS[0]}
}
|
||||
|
||||
# Print the (whitespace-stripped) node list sacct reports for the given job.
# Usage: job_nodes JOBID
job_nodes(){
  set -euo pipefail
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  # %400 widens the nodelist column so long lists are not truncated
  sacct -j "${JOBID}" -X -n --format=nodelist%400 | sed 's/ //g'
}
|
||||
|
||||
# Print the job's elapsed wall time in seconds, or "unknown" if unavailable.
# Usage: job_time JOBID
job_time(){
  set -euo pipefail
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  JOBID="$1"

  ## Note: using export so that this line doesn't cause the script to immediately exit if the subshell failed when running under set -e
  export WALLTIME=$(sacct -j "${JOBID}" --format ElapsedRaw --parsable2 --noheader | head -n 1)

  echo ${WALLTIME:-unknown}
}
|
||||
|
||||
# Block until the given job leaves the PENDING/RUNNING/REQUEUED states,
# polling sacct every 15s; always exits 0, reporting the terminal state.
# Usage: job_wait JOBID   (depends on job_state from this file)
job_wait(){
  set -euo pipefail
  shopt -s lastpipe

  if [ "$#" -ne 1 ]; then
    exit 1
  fi

  echo "checking for jobid $1"
  JOBID="$1"

  while true; do
    export STATE=$(job_state "${JOBID}")
    case "${STATE}" in
      PENDING|RUNNING|REQUEUED)
        sleep 15s
        ;;
      *)
        # extra delay lets sacct/output files settle before callers read them
        sleep 30s
        echo "Exiting with SLURM job status '${STATE}'"
        exit 0
        ;;
    esac
  done
}
|
@ -93,6 +93,15 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Multi-node distributed GPU test; expects a SLURM environment at runtime.
py_test(
    name = "distributed_multinode_test",
    srcs = ["distributed_multinode_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
||||
|
||||
py_test(
|
||||
name = "distributed_test",
|
||||
srcs = ["distributed_test.py"],
|
||||
|
62
tests/distributed_multinode_test.py
Normal file
62
tests/distributed_multinode_test.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import jax
import jax._src.lib
from jax._src import test_util as jtu
import jax.numpy as jnp
||||
|
||||
@unittest.skipIf(
    os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
    "Slurm environment with at least two nodes needed!")
class MultiNodeGpuTest(jtu.JaxTestCase):
  """Multi-node GPU smoke test driven by SLURM environment variables."""

  def test_gpu_multi_node_initialize_and_psum(self):
    """Initialize jax.distributed from SLURM env vars and check a psum."""

    # Hookup the ENV vars expected to be set already in the SLURM environment.
    # Initialize coordinator_address so a missing SLURM_STEP_NODELIST produces
    # a clean assertion failure below instead of an UnboundLocalError.
    coordinator_address = None
    nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
    if nodelist is not None:
      # e.g. "node[001,002]" -> "node001": prefix + first bracketed entry.
      # assumes a bracketed nodelist format — TODO confirm for plain lists.
      coordinator_address = (nodelist.split('[')[0] +
                             nodelist.split('[')[1].split(',')[0])
    num_tasks = os.environ.get("SLURM_NPROCS", None)
    taskid = os.environ.get("SLURM_PROCID", None)
    localid = os.environ.get("SLURM_LOCALID", None)

    # fixing port since it needs to be the same for all the processes
    port = "54321"

    print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
          f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")

    # All required SLURM variables must be present before initializing.
    self.assertFalse(
        coordinator_address is None or num_tasks is None or taskid is None)

    jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
                               num_processes=int(num_tasks),
                               process_id=int(taskid))

    print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
          f"Devices per task: {jax.local_device_count()}")

    # Global device count should be the product of tasks and local devices.
    self.assertEqual(jax.device_count(),
                     int(num_tasks) * jax.local_device_count())

    # psum of ones across all devices must equal the global device count.
    x = jnp.ones(jax.local_device_count())
    y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
    self.assertEqual(y[0], jax.device_count())
    print(y)
|
Loading…
x
Reference in New Issue
Block a user