combine gpu tests

Sudhakar 2022-08-25 15:27:07 -07:00
parent 1a365346e8
commit 4b1a2eaaec
4 changed files with 47 additions and 80 deletions

View File

@@ -33,6 +33,7 @@ read -r -d '' setup_cmd <<EOF
 python3.8 -m pip install --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html \
 && python3.8 -m pip install git+https://github.com/google/jax \
 && python3.8 -m pip install pytest \
+&& python3.8 -m pip install pytest-forked \
 && mkdir -p /workspace/outputs/
 EOF
@@ -40,9 +41,9 @@ EOF
 read -r -d '' cmd <<EOF
 date \
 && python3.8 -m pip list | grep jax \
-&& python3.8 -m pytest -v -s --continue-on-collection-errors \
+&& python3.8 -m pytest --forked -v -s --continue-on-collection-errors \
 --junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
-/workspace/tests/distributed_multinode_test.py
+/workspace/tests/multiprocess_gpu_test.py
 EOF

 # create run specific output directory for ease of analysis
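The new --forked flag comes from pytest-forked, installed above: it runs every test in its own forked subprocess. That isolation matters once the GPU tests are combined into one file, because jax.distributed.initialize() may only be called once per process. A minimal sketch of the failure mode it avoids (hypothetical tests with a placeholder coordinator address, not part of this commit):

import jax

def test_distributed_init_a():
  # Each forked test starts from a fresh process, so initialize() succeeds.
  jax.distributed.initialize(coordinator_address="127.0.0.1:54321",
                             num_processes=1, process_id=0)

def test_distributed_init_b():
  # In a shared process this second initialize() would raise; under
  # `pytest --forked` it runs in its own fork and succeeds as well.
  jax.distributed.initialize(coordinator_address="127.0.0.1:54321",
                             num_processes=1, process_id=0)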

View File: tests/BUILD

@@ -90,17 +90,8 @@ py_test(
 )

-py_test(
-    name = "distributed_multinode_test",
-    srcs = ["distributed_multinode_test.py"],
-    deps = [
-        "//jax",
-        "//jax:test_util",
-    ],
-)
-
 py_test(
-    name = "distributed_test",
-    srcs = ["distributed_test.py"],
+    name = "multiprocess_gpu_test",
+    srcs = ["multiprocess_gpu_test.py"],
     args = [
         "--exclude_test_targets=MultiProcessGpuTest",
     ],
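The --exclude_test_targets flag is consumed by jax's JaxTestLoader (the loader passed to absltest.main in the test files), so the Bazel target skips MultiProcessGpuTest and leaves it to the SLURM pipeline above. A rough sketch of that kind of regex-based test filtering, assuming an absltest-style loader (illustrative, not jax's exact implementation):

import re
from absl.testing import absltest

class FilteringTestLoader(absltest.TestLoader):
  # Stand-in for the value parsed from --exclude_test_targets.
  exclude_pattern = re.compile("MultiProcessGpuTest")

  def getTestCaseNames(self, testCaseClass):
    # Drop any test whose "Class.method" name matches the exclude pattern.
    names = super().getTestCaseNames(testCaseClass)
    return [n for n in names
            if not self.exclude_pattern.search(f"{testCaseClass.__name__}.{n}")]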

View File: tests/distributed_multinode_test.py (deleted)

@@ -1,67 +0,0 @@
-# Copyright 2022 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import unittest
-
-from absl.testing import absltest
-
-import jax
-import jax._src.lib
-from jax._src import test_util as jtu
-import jax.numpy as jnp
-
-
-@unittest.skipIf(
-    os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
-    "Slurm environment with at least two nodes needed!")
-class MultiNodeGpuTest(jtu.JaxTestCase):
-
-  def test_gpu_multi_node_initialize_and_psum(self):
-    # Hookup the ENV vars expected to be set already in the SLURM environment
-    nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
-    if nodelist is not None:
-      coordinator_address = nodelist.split('[')[0] + \
-                            nodelist.split('[')[1].split(',')[0]
-    num_tasks = os.environ.get("SLURM_NPROCS", None)
-    taskid = os.environ.get("SLURM_PROCID", None)
-    localid = os.environ.get("SLURM_LOCALID", None)
-
-    # fixing port since it needs to be the same for all the processes
-    port = "54321"
-
-    print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
-          f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")
-
-    self.assertEqual(
-        coordinator_address is None or num_tasks is None or taskid is None,
-        False)
-
-    jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
-                               num_processes=int(num_tasks),
-                               process_id=int(taskid))
-
-    print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
-          f"Devices per task: {jax.local_device_count()}")
-
-    self.assertEqual(jax.device_count(),
-                     int(num_tasks) * jax.local_device_count())
-
-    x = jnp.ones(jax.local_device_count())
-    y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
-    self.assertEqual(y[0], jax.device_count())
-    print(y)
-
-
-if __name__ == "__main__":
-  absltest.main(testLoader=jtu.JaxTestLoader())
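The coordinator address above is derived from SLURM_STEP_NODELIST by plain string splitting: the prefix before the '[' plus the first comma-separated entry inside it. A standalone sketch of that parsing (hypothetical helper name; the rstrip(']') handles single-entry lists, which the inline version above assumes never occur since the test requires two nodes):

def coordinator_host(nodelist):
  # "gpu[12,14]" -> prefix "gpu" + first bracket entry "12" -> "gpu12"
  prefix, bracketed = nodelist.split('[', 1)
  first = bracketed.split(',', 1)[0].rstrip(']')
  return prefix + first

assert coordinator_host("gpu[12,14]") == "gpu12"
assert coordinator_host("node[3,5-8]") == "node3"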

View File: tests/distributed_test.py → tests/multiprocess_gpu_test.py

@@ -24,6 +24,7 @@ from absl.testing import parameterized

 import jax
 from jax.config import config
 from jax._src import distributed
+import jax.numpy as jnp
 from jax._src.lib import xla_extension_version
 from jax._src import test_util as jtu
@@ -146,5 +147,46 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
       self.assertEqual(proc.returncode, 0)
       self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')

+
+@unittest.skipIf(
+    os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
+    "Slurm environment with at least two nodes needed!")
+class MultiNodeGpuTest(jtu.JaxTestCase):
+
+  def test_gpu_multi_node_initialize_and_psum(self):
+    # Hookup the ENV vars expected to be set already in the SLURM environment
+    nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
+    if nodelist is not None:
+      coordinator_address = nodelist.split('[')[0] + \
+                            nodelist.split('[')[1].split(',')[0]
+    num_tasks = os.environ.get("SLURM_NPROCS", None)
+    taskid = os.environ.get("SLURM_PROCID", None)
+    localid = os.environ.get("SLURM_LOCALID", None)
+
+    # fixing port since it needs to be the same for all the processes
+    port = "54321"
+
+    print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
+          f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")
+
+    self.assertEqual(
+        coordinator_address is None or num_tasks is None or taskid is None,
+        False)
+
+    jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
+                               num_processes=int(num_tasks),
+                               process_id=int(taskid))
+
+    print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
+          f"Devices per task: {jax.local_device_count()}")
+
+    self.assertEqual(jax.device_count(),
+                     int(num_tasks) * jax.local_device_count())
+
+    x = jnp.ones(jax.local_device_count())
+    y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
+    self.assertEqual(y[0], jax.device_count())
+    print(y)

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
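The closing assertions rest on a simple psum identity: pmapping a ones-vector and psum-ming over the named axis makes every element equal the number of participating devices, which in the multi-node run is the global device count. A single-process analogue, runnable on any machine with JAX installed:

import jax
import jax.numpy as jnp

# One entry per addressable device; pmap places one entry on each.
x = jnp.ones(jax.local_device_count())
y = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)

# With a single process, device_count() == local_device_count(), so each
# element of y is the total number of devices that joined the psum.
assert y[0] == jax.device_count()
print(y)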