Enable one GPU per process in multi-node GPU CI

Sudhakar 2022-08-29 09:00:19 -07:00
parent b9d7e05eda
commit a571db18db
2 changed files with 11 additions and 7 deletions

View File

@@ -6,8 +6,7 @@
 #SBATCH -J "ci-jax-gpu" # job name
 #SBATCH --exclusive # exclusive node access
 #SBATCH --mem=0 # all mem avail
-#SBATCH --mail-type=FAIL # only send email on failure
-#SBATCH --ntasks-per-node=1 # 1 tasks per machine for now
+#SBATCH --mail-type=FAIL # only send email on failures
 #SBATCH --overcommit # Needed for pytorch
 set -x
@@ -57,6 +56,7 @@ OUTFILE="${OUTPUT_DIR}/output-%j-%n.txt"
 # that the processes are launched together
 echo $setup_cmd
 srun -o $OUTFILE -e $OUTFILE \
+    --ntasks-per-node=1 \
     --container-writable \
     --container-image="$CONTAINER" \
     --container-name=$CONTAINER_NAME \
@@ -70,6 +70,7 @@ wait
 # Run the actual pytest command
 echo $cmd
 srun -o $OUTFILE -e $OUTFILE \
+    --ntasks-per-node=8 \
    --open-mode=append \
    --container-writable \
    --container-image="$CONTAINER" \
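
Note (not part of the diff): the job-level SBATCH --ntasks-per-node=1 setting is removed above and replaced by per-step srun flags, so the container setup step still launches one task per node while the pytest step launches eight tasks per node, letting each task claim a single GPU. The following is a minimal, illustrative sketch of what each of those eight tasks then does on its own GPU; it is not part of the CI script, and the environment variable default is an assumption.

import os
import jax

# Restrict this process to the GPU matching its local rank on the node,
# mirroring the jax_cuda_visible_devices change in the test file below.
local_rank = os.environ.get("SLURM_LOCALID", "0")
jax.config.update("jax_cuda_visible_devices", local_rank)

print(jax.local_device_count())  # expected: 1, i.e. one GPU per process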

View File

@@ -151,15 +151,15 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
 @unittest.skipIf(
     os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
     "Slurm environment with at least two nodes needed!")
-class MultiNodeGpuTest(jtu.JaxTestCase):
+class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
 
   def test_gpu_multi_node_initialize_and_psum(self):
     # Hookup the ENV vars expected to be set already in the SLURM environment
-    nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
-    if nodelist is not None:
-      coordinator_address = nodelist.split('[')[0] + \
-                            nodelist.split('[')[1].split(',')[0]
+    coordinator_address = os.environ.get("SLURM_STEP_NODELIST", None)
+    if coordinator_address is not None and '[' in coordinator_address:
+      coordinator_address = coordinator_address.split('[')[0] + \
+                            coordinator_address.split('[')[1].split(',')[0]
     num_tasks = os.environ.get("SLURM_NPROCS", None)
     taskid = os.environ.get("SLURM_PROCID", None)
     localid = os.environ.get("SLURM_LOCALID", None)
@@ -174,6 +174,9 @@ class MultiNodeGpuTest(jtu.JaxTestCase):
         coordinator_address is None or num_tasks is None or taskid is None,
         False)
+    # os.environ["CUDA_VISIBLE_DEVICES"] = localid #WAR for Bug:12119
+    jax.config.update("jax_cuda_visible_devices", localid)
     jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
                                num_processes=int(num_tasks),
                                process_id=int(taskid))
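
Note (not part of the diff): the coordinator address above is taken from SLURM_STEP_NODELIST, keeping the prefix before '[' plus the first comma-separated entry inside the brackets. Below is a standalone sketch of that parsing with an assumed helper name and sample values; compressed ranges such as "node[1-2]" would need extra handling beyond the split-based logic shown here.

import os

def coordinator_from_nodelist(nodelist):
  # "node3"      -> "node3"   (no brackets: already a single host)
  # "node[3,5]"  -> "node3"   (prefix + first comma-separated entry)
  if '[' not in nodelist:
    return nodelist
  prefix = nodelist.split('[')[0]
  first = nodelist.split('[')[1].split(',')[0]
  return prefix + first

print(coordinator_from_nodelist(os.environ.get("SLURM_STEP_NODELIST", "node[3,5]")))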