rocm_jax/build/rocm/run_single_gpu.py

205 lines
6.2 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import argparse
import threading
import subprocess
from concurrent.futures import ThreadPoolExecutor
GPU_LOCK = threading.Lock()
LAST_CODE = 0
base_dir = "./logs"
def extract_filename(path):
base_name = os.path.basename(path)
file_name, _ = os.path.splitext(base_name)
return file_name
def combine_json_reports():
all_json_files = [f for f in os.listdir(base_dir) if f.endswith("_log.json")]
combined_data = []
for json_file in all_json_files:
with open(os.path.join(base_dir, json_file), "r") as infile:
data = json.load(infile)
combined_data.append(data)
combined_json_file = f"{base_dir}/final_compiled_report.json"
with open(combined_json_file, "w") as outfile:
json.dump(combined_data, outfile, indent=4)
def generate_final_report(shell=False, env_vars={}):
env = os.environ
env = {**env, **env_vars}
cmd = [
"pytest_html_merger",
"-i",
f"{base_dir}",
"-o",
f"{base_dir}/final_compiled_report.html",
]
result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
# Generate json reports.
combine_json_reports()
def run_shell_command(cmd, shell=False, env_vars={}):
env = os.environ
env = {**env, **env_vars}
result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
return result.returncode, result.stderr.decode(), result.stdout.decode()
def parse_test_log(log_file):
"""Parses the test module log file to extract test modules and functions."""
test_files = set()
with open(log_file, "r") as f:
for line in f:
report = json.loads(line)
if "nodeid" in report:
module = report["nodeid"].split("::")[0]
if module and ".py" in module:
test_files.add(os.path.abspath(module))
return test_files
def collect_testmodules():
log_file = f"{base_dir}/collect_module_log.jsonl"
return_code, stderr, stdout = run_shell_command(
[
"python3",
"-m",
"pytest",
"--collect-only",
"tests",
f"--report-log={log_file}",
]
)
if return_code != 0:
print("Test module discovery failed.")
print("STDOUT:", stdout)
print("STDERR:", stderr)
exit(return_code)
print("---------- collected test modules ----------")
test_files = parse_test_log(log_file)
print("Found %d test modules." % (len(test_files)))
print("--------------------------------------------")
print("\n".join(test_files))
return test_files
def run_test(testmodule, gpu_tokens, continue_on_fail):
global LAST_CODE
with GPU_LOCK:
if LAST_CODE != 0:
return
target_gpu = gpu_tokens.pop()
env_vars = {
"HIP_VISIBLE_DEVICES": str(target_gpu),
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
}
testfile = extract_filename(testmodule)
if continue_on_fail:
cmd = [
"python3",
"-m",
"pytest",
"--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json",
f"--html={base_dir}/{testfile}_log.html",
"--reruns",
"3",
"-v",
testmodule,
]
else:
cmd = [
"python3",
"-m",
"pytest",
"--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json",
f"--html={base_dir}/{testfile}_log.html",
"--reruns",
"3",
"-x",
"-v",
testmodule,
]
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK:
gpu_tokens.append(target_gpu)
if LAST_CODE == 0:
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
print(stdout)
print(stderr)
if continue_on_fail == False:
LAST_CODE = return_code
def run_parallel(all_testmodules, p, c):
print(f"Running tests with parallelism = {p}")
available_gpu_tokens = list(range(p))
executor = ThreadPoolExecutor(max_workers=p)
# walking through test modules.
for testmodule in all_testmodules:
executor.submit(run_test, testmodule, available_gpu_tokens, c)
# waiting for all modules to finish.
executor.shutdown(wait=True)
def find_num_gpus():
cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
_, _, stdout = run_shell_command(cmd, shell=True)
return int(stdout)
def main(args):
all_testmodules = collect_testmodules()
run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
generate_final_report()
exit(LAST_CODE)
if __name__ == "__main__":
os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--parallel", type=int, help="number of tests to run in parallel"
)
parser.add_argument(
"-c", "--continue_on_fail", action="store_true", help="continue on failure"
)
args = parser.parse_args()
if args.continue_on_fail:
print("continue on fail is set")
if args.parallel is None:
sys_gpu_count = find_num_gpus()
args.parallel = sys_gpu_count
print("%d GPUs detected." % sys_gpu_count)
main(args)