mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15317 from ROCmSoftwarePlatform:rocm_pmap_fix
PiperOrigin-RevId: 520934992
This commit is contained in:
commit
dfbbc2551c
@ -16,5 +16,17 @@
|
||||
set -eux
|
||||
# run test module with multi-gpu requirements. We currently do not have a way to filter tests.
|
||||
# this issue is also tracked in https://github.com/google/jax/issues/7323
|
||||
python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
cmd=$(lspci|grep 'controller'|grep 'AMD/ATI'|wc -l)
|
||||
echo $cmd
|
||||
|
||||
if [[ $cmd -gt 8 ]]; then
|
||||
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
elif [[ $cmd -gt 4 ]]; then
|
||||
export HIP_VISIBLE_DEVICES=0,1,2,3 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
elif [[ $cmd -gt 2 ]]; then
|
||||
export HIP_VISIBLE_DEVICES=0,1 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
else
|
||||
export HIP_VISIBLE_DEVICES=0 && python3 -m pytest --reruns 3 -x tests/pmap_test.py
|
||||
fi
|
||||
|
||||
python3 -m pytest --reruns 3 -x tests/multi_device_test.py
|
||||
|
Loading…
x
Reference in New Issue
Block a user