mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Enable one gpu per process in multinode GPU CI
This commit is contained in:
parent
b9d7e05eda
commit
a571db18db
@ -6,8 +6,7 @@
|
||||
#SBATCH -J "ci-jax-gpu" # job name
|
||||
#SBATCH --exclusive # exclusive node access
|
||||
#SBATCH --mem=0 # all mem avail
|
||||
#SBATCH --mail-type=FAIL # only send email on failure
|
||||
#SBATCH --ntasks-per-node=1 # 1 tasks per machine for now
|
||||
#SBATCH --mail-type=FAIL # only send email on failures
|
||||
#SBATCH --overcommit # Needed for pytorch
|
||||
|
||||
set -x
|
||||
@ -57,6 +56,7 @@ OUTFILE="${OUTPUT_DIR}/output-%j-%n.txt"
|
||||
# that the processes are launched together
|
||||
echo $setup_cmd
|
||||
srun -o $OUTFILE -e $OUTFILE \
|
||||
--ntasks-per-node=1 \
|
||||
--container-writable \
|
||||
--container-image="$CONTAINER" \
|
||||
--container-name=$CONTAINER_NAME \
|
||||
@ -70,6 +70,7 @@ wait
|
||||
# Run the actual pytest command
|
||||
echo $cmd
|
||||
srun -o $OUTFILE -e $OUTFILE \
|
||||
--ntasks-per-node=8 \
|
||||
--open-mode=append \
|
||||
--container-writable \
|
||||
--container-image="$CONTAINER" \
|
||||
|
@ -151,15 +151,15 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
|
||||
@unittest.skipIf(
|
||||
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
|
||||
"Slurm environment with at least two nodes needed!")
|
||||
class MultiNodeGpuTest(jtu.JaxTestCase):
|
||||
class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
|
||||
def test_gpu_multi_node_initialize_and_psum(self):
|
||||
|
||||
# Hookup the ENV vars expected to be set already in the SLURM environment
|
||||
nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
|
||||
if nodelist is not None:
|
||||
coordinator_address = nodelist.split('[')[0] + \
|
||||
nodelist.split('[')[1].split(',')[0]
|
||||
coordinator_address = os.environ.get("SLURM_STEP_NODELIST", None)
|
||||
if coordinator_address is not None and '[' in coordinator_address:
|
||||
coordinator_address = coordinator_address.split('[')[0] + \
|
||||
coordinator_address.split('[')[1].split(',')[0]
|
||||
num_tasks = os.environ.get("SLURM_NPROCS", None)
|
||||
taskid = os.environ.get("SLURM_PROCID", None)
|
||||
localid = os.environ.get("SLURM_LOCALID", None)
|
||||
@ -174,6 +174,9 @@ class MultiNodeGpuTest(jtu.JaxTestCase):
|
||||
coordinator_address is None or num_tasks is None or taskid is None,
|
||||
False)
|
||||
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = localid #WAR for Bug:12119
|
||||
jax.config.update("jax_cuda_visible_devices", localid)
|
||||
|
||||
jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
|
||||
num_processes=int(num_tasks),
|
||||
process_id=int(taskid))
|
||||
|
Loading…
x
Reference in New Issue
Block a user