diff --git a/ci/envs/default.env b/ci/envs/default.env
index e2bcfc26b..72646113e 100644
--- a/ci/envs/default.env
+++ b/ci/envs/default.env
@@ -53,13 +53,6 @@ export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}
 # #############################################################################
 # Test script specific environment variables.
 # #############################################################################
-# The maximum number of tests to run per GPU when running single accelerator
-# tests with parallel execution with Bazel. The GPU limit is set because we
-# need to allow about 2GB of GPU RAM per test. Default is set to 12 because we
-# use L4 machines which have 24GB of RAM but can be overriden if we use a
-# different GPU type.
-export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12}
-
 # Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override
 # this value in the Github action workflow files.
 export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0}
diff --git a/ci/run_bazel_test_cuda_non_rbe.sh b/ci/run_bazel_test_cuda_non_rbe.sh
index 8c0880353..176efd344 100755
--- a/ci/run_bazel_test_cuda_non_rbe.sh
+++ b/ci/run_bazel_test_cuda_non_rbe.sh
@@ -37,14 +37,30 @@ nvidia-smi
 
 echo "Running single accelerator tests (without RBE)..."
 
 # Set up test environment variables.
+# Set the number of test jobs to min(num_cpu_cores, gpu_count * max_tests_per_gpu, total_ram_gb / 6)
+# We calculate max_tests_per_gpu as memory_per_gpu_gb / 2gb
+# Calculate gpu_count * max_tests_per_gpu
 export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
-export num_test_jobs=$((gpu_count * JAXCI_MAX_TESTS_PER_GPU))
+export memory_per_gpu_gb=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits --id=0)
+export memory_per_gpu_gb=$((memory_per_gpu_gb / 1024))
+# Allow 2 GB of GPU RAM per test
+export max_tests_per_gpu=$((memory_per_gpu_gb / 2))
+export num_test_jobs=$((gpu_count * max_tests_per_gpu))
+
+# Calculate num_cpu_cores
 export num_cpu_cores=$(nproc)
-# tests_jobs = max(gpu_count * max_tests_per_gpu, num_cpu_cores)
-if [[ $num_test_jobs -gt $num_cpu_cores ]]; then
+# Calculate total_ram_gb / 6
+export total_ram_gb=$(awk '/MemTotal/ {printf "%.0f", $2/1048576}' /proc/meminfo)
+export host_memory_limit=$((total_ram_gb / 6))
+
+if [[ $num_cpu_cores -lt $num_test_jobs ]]; then
   num_test_jobs=$num_cpu_cores
 fi
+
+if [[ $host_memory_limit -lt $num_test_jobs ]]; then
+  num_test_jobs=$host_memory_limit
+fi
 # End of test environment variables setup.
 
 # Don't abort the script if one command fails to ensure we run both test
@@ -64,7 +80,7 @@ bazel test --config=ci_linux_x86_64_cuda \
       --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
       --test_output=errors \
       --test_env=JAX_ACCELERATOR_COUNT=$gpu_count \
-      --test_env=JAX_TESTS_PER_ACCELERATOR=$JAXCI_MAX_TESTS_PER_GPU \
+      --test_env=JAX_TESTS_PER_ACCELERATOR=$max_tests_per_gpu \
       --local_test_jobs=$num_test_jobs \
       --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
       --test_tag_filters=-multiaccelerator \
diff --git a/ci/run_pytest_cuda.sh b/ci/run_pytest_cuda.sh
index eb815de14..d98068385 100755
--- a/ci/run_pytest_cuda.sh
+++ b/ci/run_pytest_cuda.sh
@@ -45,9 +45,30 @@ export NCCL_DEBUG=WARN
 export TF_CPP_MIN_LOG_LEVEL=0
 export JAX_ENABLE_64="$JAXCI_ENABLE_X64"
 
-# Set the number of processes to run to be 4x the number of GPUs.
+# Set the number of processes to min(num_cpu_cores, gpu_count * max_tests_per_gpu, total_ram_gb / 6)
+# We calculate max_tests_per_gpu as memory_per_gpu_gb / 2gb
+# Calculate gpu_count * max_tests_per_gpu
 export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
-export num_processes=`expr 4 \* $gpu_count`
+export memory_per_gpu_gb=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits --id=0)
+export memory_per_gpu_gb=$((memory_per_gpu_gb / 1024))
+# Allow 2 GB of GPU RAM per test
+export max_tests_per_gpu=$((memory_per_gpu_gb / 2))
+export num_processes=$((gpu_count * max_tests_per_gpu))
+
+# Calculate num_cpu_cores
+export num_cpu_cores=$(nproc)
+
+# Calculate total_ram_gb / 6
+export total_ram_gb=$(awk '/MemTotal/ {printf "%.0f", $2/1048576}' /proc/meminfo)
+export host_memory_limit=$((total_ram_gb / 6))
+
+if [[ $num_cpu_cores -lt $num_processes ]]; then
+  num_processes=$num_cpu_cores
+fi
+
+if [[ $host_memory_limit -lt $num_processes ]]; then
+  num_processes=$host_memory_limit
+fi
 
 export XLA_PYTHON_CLIENT_ALLOCATOR=platform
 export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1