Merge pull request #15317 from ROCmSoftwarePlatform:rocm_pmap_fix

PiperOrigin-RevId: 520934992
This commit is contained in:
jax authors 2023-03-31 09:05:07 -07:00
commit dfbbc2551c

View File

@ -16,5 +16,17 @@
set -eux
# run test module with multi-gpu requirements. We currently do not have a way to filter tests.
# this issue is also tracked in https://github.com/google/jax/issues/7323
python3 -m pytest --reruns 3 -x tests/pmap_test.py
cmd=$(lspci|grep 'controller'|grep 'AMD/ATI'|wc -l)
echo $cmd
if [[ $cmd -gt 8 ]]; then
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
elif [[ $cmd -gt 4 ]]; then
export HIP_VISIBLE_DEVICES=0,1,2,3 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
elif [[ $cmd -gt 2 ]]; then
export HIP_VISIBLE_DEVICES=0,1 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
else
export HIP_VISIBLE_DEVICES=0 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
fi
python3 -m pytest --reruns 3 -x tests/multi_device_test.py