[ROCm]: Run pmap test on specific number of GPUs

This commit is contained in:
Rahul Batra 2023-03-30 18:34:47 +00:00
parent 67a28ce30f
commit 13e45c8953

View File

@ -16,5 +16,17 @@
set -eux
# run test module with multi-gpu requirements. We currently do not have a way to filter tests.
# this issue is also tracked in https://github.com/google/jax/issues/7323
python3 -m pytest --reruns 3 -x tests/pmap_test.py
cmd=$(lspci|grep 'controller'|grep 'AMD/ATI'|wc -l)
echo $cmd
if [[ $cmd -gt 8 ]]; then
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
elif [[ $cmd -gt 4 ]]; then
export HIP_VISIBLE_DEVICES=0,1,2,3 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
elif [[ $cmd -gt 2 ]]; then
export HIP_VISIBLE_DEVICES=0,1 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
else
export HIP_VISIBLE_DEVICES=0 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
fi
python3 -m pytest --reruns 3 -x tests/multi_device_test.py