mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[ROCm]: Run pmap test on specific number of GPUs
This commit is contained in:
parent
67a28ce30f
commit
13e45c8953
@ -16,5 +16,17 @@
|
||||
set -eux
|
||||
# run test module with multi-gpu requirements. We currently do not have a way to filter tests.
|
||||
# this issue is also tracked in https://github.com/google/jax/issues/7323
|
||||
python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
cmd=$(lspci|grep 'controller'|grep 'AMD/ATI'|wc -l)
|
||||
echo $cmd
|
||||
|
||||
if [[ $cmd -gt 8 ]]; then
|
||||
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
elif [[ $cmd -gt 4 ]]; then
|
||||
export HIP_VISIBLE_DEVICES=0,1,2,3 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
elif [[ $cmd -gt 2 ]]; then
|
||||
export HIP_VISIBLE_DEVICES=0,1 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
else
|
||||
export HIP_VISIBLE_DEVICES=0 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
fi
|
||||
|
||||
python3 -m pytest --reruns 3 -x tests/multi_device_test.py
|
||||
|
Loading…
x
Reference in New Issue
Block a user