mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
combine gpu tests
This commit is contained in:
parent
1a365346e8
commit
4b1a2eaaec
@ -33,6 +33,7 @@ read -r -d '' setup_cmd <<EOF
|
||||
python3.8 -m pip install --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html \
|
||||
&& python3.8 -m pip install git+https://github.com/google/jax \
|
||||
&& python3.8 -m pip install pytest \
|
||||
&& python3.8 -m pip install pytest-forked \
|
||||
&& mkdir -p /workspace/outputs/
|
||||
EOF
|
||||
|
||||
@ -40,9 +41,9 @@ EOF
|
||||
read -r -d '' cmd <<EOF
|
||||
date \
|
||||
&& python3.8 -m pip list | grep jax \
|
||||
&& python3.8 -m pytest -v -s --continue-on-collection-errors \
|
||||
&& python3.8 -m pytest --forked -v -s --continue-on-collection-errors \
|
||||
--junit-xml=/workspace/outputs/junit_output_\${SLURM_PROCID}.xml \
|
||||
/workspace/tests/distributed_multinode_test.py
|
||||
/workspace/tests/multiprocess_gpu_test.py
|
||||
EOF
|
||||
|
||||
# create run specific output directory for ease of analysis
|
||||
|
13
tests/BUILD
13
tests/BUILD
@ -90,17 +90,8 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "distributed_multinode_test",
|
||||
srcs = ["distributed_multinode_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "distributed_test",
|
||||
srcs = ["distributed_test.py"],
|
||||
name = "multiprocess_gpu_test",
|
||||
srcs = ["multiprocess_gpu_test.py"],
|
||||
args = [
|
||||
"--exclude_test_targets=MultiProcessGpuTest",
|
||||
],
|
||||
|
@ -1,67 +0,0 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax._src.lib
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
|
||||
"Slurm environment with at least two nodes needed!")
|
||||
class MultiNodeGpuTest(jtu.JaxTestCase):
|
||||
|
||||
def test_gpu_multi_node_initialize_and_psum(self):
|
||||
|
||||
# Hookup the ENV vars expected to be set already in the SLURM environment
|
||||
nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
|
||||
if nodelist is not None:
|
||||
coordinator_address = nodelist.split('[')[0] + \
|
||||
nodelist.split('[')[1].split(',')[0]
|
||||
num_tasks = os.environ.get("SLURM_NPROCS", None)
|
||||
taskid = os.environ.get("SLURM_PROCID", None)
|
||||
localid = os.environ.get("SLURM_LOCALID", None)
|
||||
|
||||
# fixing port since it needs to be the same for all the processes
|
||||
port = "54321"
|
||||
|
||||
print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
|
||||
f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")
|
||||
|
||||
self.assertEqual(
|
||||
coordinator_address is None or num_tasks is None or taskid is None,
|
||||
False)
|
||||
|
||||
jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
|
||||
num_processes=int(num_tasks),
|
||||
process_id=int(taskid))
|
||||
|
||||
print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
|
||||
f"Devices per task: {jax.local_device_count()}")
|
||||
|
||||
self.assertEqual(jax.device_count(),
|
||||
int(num_tasks) * jax.local_device_count())
|
||||
|
||||
x = jnp.ones(jax.local_device_count())
|
||||
y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
|
||||
self.assertEqual(y[0], jax.device_count())
|
||||
print(y)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax._src import distributed
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
@ -146,5 +147,46 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
|
||||
self.assertEqual(proc.returncode, 0)
|
||||
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
|
||||
"Slurm environment with at least two nodes needed!")
|
||||
class MultiNodeGpuTest(jtu.JaxTestCase):
|
||||
|
||||
def test_gpu_multi_node_initialize_and_psum(self):
|
||||
|
||||
# Hookup the ENV vars expected to be set already in the SLURM environment
|
||||
nodelist = os.environ.get("SLURM_STEP_NODELIST", None)
|
||||
if nodelist is not None:
|
||||
coordinator_address = nodelist.split('[')[0] + \
|
||||
nodelist.split('[')[1].split(',')[0]
|
||||
num_tasks = os.environ.get("SLURM_NPROCS", None)
|
||||
taskid = os.environ.get("SLURM_PROCID", None)
|
||||
localid = os.environ.get("SLURM_LOCALID", None)
|
||||
|
||||
# fixing port since it needs to be the same for all the processes
|
||||
port = "54321"
|
||||
|
||||
print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
|
||||
f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")
|
||||
|
||||
self.assertEqual(
|
||||
coordinator_address is None or num_tasks is None or taskid is None,
|
||||
False)
|
||||
|
||||
jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
|
||||
num_processes=int(num_tasks),
|
||||
process_id=int(taskid))
|
||||
|
||||
print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
|
||||
f"Devices per task: {jax.local_device_count()}")
|
||||
|
||||
self.assertEqual(jax.device_count(),
|
||||
int(num_tasks) * jax.local_device_count())
|
||||
|
||||
x = jnp.ones(jax.local_device_count())
|
||||
y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
|
||||
self.assertEqual(y[0], jax.device_count())
|
||||
print(y)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user