mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
205 lines
6.2 KiB
Python
Executable File
205 lines
6.2 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import json
|
|
import argparse
|
|
import threading
|
|
import subprocess
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
# Directory that receives every per-module pytest artifact (JSON/HTML
# reports) as well as the test-collection log.
base_dir = "./logs"

# First nonzero pytest return code seen in fail-fast mode; 0 means no failure
# has been recorded yet. It is also used as the script's exit status.
LAST_CODE = 0

# Guards LAST_CODE and the shared list of free GPU indices across the worker
# threads.
GPU_LOCK = threading.Lock()
|
|
|
|
|
|
def extract_filename(path):
    """Return the last component of *path* with its final extension removed.

    ``"tests/foo_test.py"`` becomes ``"foo_test"``; only the last suffix is
    dropped, so ``"a.tar.gz"`` becomes ``"a.tar"``.
    """
    stem, _extension = os.path.splitext(os.path.basename(path))
    return stem
|
|
|
|
|
|
def combine_json_reports(log_dir=None):
    """Merge every per-module ``*_log.json`` report into a single JSON array.

    Each file ending in ``_log.json`` directly inside *log_dir* is parsed, and
    the list of parsed documents (in file-name order) is written to
    ``<log_dir>/final_compiled_report.json``. Files that are not valid JSON
    are reported and skipped.

    Args:
      log_dir: directory containing the per-module reports. Defaults to the
        module-level ``base_dir``.
    """
    if log_dir is None:
        log_dir = base_dir
    # Sort so the combined report is stable; os.listdir order is arbitrary.
    report_files = sorted(
        name for name in os.listdir(log_dir) if name.endswith("_log.json")
    )
    combined_data = []
    for report_file in report_files:
        report_path = os.path.join(log_dir, report_file)
        with open(report_path, "r", encoding="utf-8") as infile:
            try:
                data = json.load(infile)
            except json.JSONDecodeError as err:
                # A malformed or truncated report (e.g. from an interrupted
                # run) should not prevent the rest from being combined.
                print(f"WARNING - skipping unreadable report {report_path}: {err}")
                continue
        combined_data.append(data)
    combined_json_file = os.path.join(log_dir, "final_compiled_report.json")
    with open(combined_json_file, "w", encoding="utf-8") as outfile:
        json.dump(combined_data, outfile, indent=4)
|
|
|
|
def generate_final_report(shell=False, env_vars=None):
    """Merge the per-module HTML reports, then combine the JSON reports.

    Runs ``pytest_html_merger`` over ``base_dir`` to write
    ``final_compiled_report.html`` and afterwards calls
    ``combine_json_reports``. A failed or missing HTML merger is reported but
    does not prevent the combined JSON report from being written.

    Args:
      shell: forwarded to ``subprocess.run``.
      env_vars: extra environment variables layered over ``os.environ``.
    """
    env = {**os.environ, **(env_vars or {})}
    cmd = [
        "pytest_html_merger",
        "-i",
        f"{base_dir}",
        "-o",
        f"{base_dir}/final_compiled_report.html",
    ]
    try:
        result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
    except FileNotFoundError as err:
        # pytest_html_merger is not installed / not on PATH.
        print("FAILED - {}".format(" ".join(cmd)))
        print(err)
    else:
        if result.returncode != 0:
            print("FAILED - {}".format(" ".join(cmd)))
            print(result.stderr.decode(errors="replace"))

    # Generate json reports.
    combine_json_reports()
|
|
|
|
|
|
def run_shell_command(cmd, shell=False, env_vars=None):
    """Run *cmd* and return ``(returncode, stderr, stdout)`` with text output.

    Args:
      cmd: argument list, or a command string when ``shell=True``.
      shell: forwarded to ``subprocess.run``.
      env_vars: extra environment variables layered over ``os.environ``.

    On a nonzero exit the command and its stderr are printed. Output is
    decoded as UTF-8 with undecodable bytes replaced, so unusual bytes in
    test output cannot raise ``UnicodeDecodeError``.
    """
    env = {**os.environ, **(env_vars or {})}
    result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
    stderr = result.stderr.decode(errors="replace")
    stdout = result.stdout.decode(errors="replace")
    if result.returncode != 0:
        # A plain string command must not be joined character by character.
        cmd_text = cmd if isinstance(cmd, str) else " ".join(cmd)
        print(f"FAILED - {cmd_text}")
        print(stderr)
    return result.returncode, stderr, stdout
|
|
|
|
|
|
def parse_test_log(log_file):
    """Parses the test module log file to extract test modules and functions.

    *log_file* is read as JSON lines (one JSON object per non-blank line).
    For each object with a non-empty ``nodeid``, the text before the first
    ``::`` is treated as the module path and kept if it names a ``.py`` file.

    Returns:
      A set of absolute module paths, resolved against the current directory.
    """
    test_files = set()
    with open(log_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing in json.loads.
                continue
            report = json.loads(line)
            nodeid = report.get("nodeid")
            if not nodeid:
                continue
            module = nodeid.split("::")[0]
            # Suffix check, not substring: e.g. "stubs.pyi" is not a test module.
            if module.endswith(".py"):
                test_files.add(os.path.abspath(module))
    return test_files
|
|
|
|
|
|
def collect_testmodules():
    """Discover test modules under ``tests`` with ``pytest --collect-only``.

    The collection report log is written to
    ``<base_dir>/collect_module_log.jsonl`` and parsed by ``parse_test_log``.
    If collection fails, pytest's output is printed and the process exits
    with pytest's return code.

    Returns:
      The set of absolute test-module paths.
    """
    # base_dir is not created anywhere else in this script; make sure it
    # exists before pytest and the later report steps write into it.
    os.makedirs(base_dir, exist_ok=True)
    log_file = f"{base_dir}/collect_module_log.jsonl"
    return_code, stderr, stdout = run_shell_command(
        [
            "python3",
            "-m",
            "pytest",
            "--collect-only",
            "tests",
            f"--report-log={log_file}",
        ]
    )
    if return_code != 0:
        print("Test module discovery failed.")
        print("STDOUT:", stdout)
        print("STDERR:", stderr)
        # Equivalent to sys.exit(); the bare exit() helper is only installed
        # by the site module and is meant for the interactive shell.
        raise SystemExit(return_code)
    print("---------- collected test modules ----------")
    test_files = parse_test_log(log_file)
    print("Found %d test modules." % (len(test_files)))
    print("--------------------------------------------")
    # Sorted so the printed list is stable between runs.
    print("\n".join(sorted(test_files)))
    return test_files
|
|
|
|
|
|
def _pytest_cmd(testmodule, testfile, continue_on_fail):
    """Build the pytest command line that runs one test module."""
    cmd = [
        "python3",
        "-m",
        "pytest",
        "--json-report",
        f"--json-report-file={base_dir}/{testfile}_log.json",
        f"--html={base_dir}/{testfile}_log.html",
        "--reruns",
        "3",
    ]
    if not continue_on_fail:
        # Stop this module at its first failing test.
        cmd.append("-x")
    cmd += ["-v", testmodule]
    return cmd


def run_test(testmodule, gpu_tokens, continue_on_fail):
    """Run one test module under pytest, pinned to a single borrowed GPU.

    Args:
      testmodule: path of the test module to run.
      gpu_tokens: shared list of free GPU indices. One index is popped for
        the duration of the run and appended back afterwards; access is
        guarded by GPU_LOCK.
      continue_on_fail: when False, pytest gets ``-x`` and the module's
        return code is stored in LAST_CODE, so modules that have not started
        yet return immediately once a failure is recorded.
    """
    global LAST_CODE
    testfile = extract_filename(testmodule)
    cmd = _pytest_cmd(testmodule, testfile, continue_on_fail)

    with GPU_LOCK:
        if LAST_CODE != 0:
            # Fail-fast mode and an earlier module already failed.
            return
        target_gpu = gpu_tokens.pop()

    env_vars = {
        "HIP_VISIBLE_DEVICES": str(target_gpu),
        "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
    }
    try:
        return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
    finally:
        # Always hand the GPU back, even if launching the run raised;
        # otherwise a later worker could pop from an exhausted list.
        with GPU_LOCK:
            gpu_tokens.append(target_gpu)

    with GPU_LOCK:
        if LAST_CODE == 0:
            print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
            print(stdout)
            print(stderr)
            if not continue_on_fail:
                LAST_CODE = return_code
|
|
|
|
|
|
def run_parallel(all_testmodules, p, c):
    """Run every test module through ``run_test`` with at most *p* in flight.

    Args:
      all_testmodules: iterable of test-module paths.
      p: parallelism; also the number of GPU tokens (GPU indices 0..p-1).
      c: ``continue_on_fail`` flag forwarded to ``run_test``.

    Raises:
      ValueError: if *p* is less than 1.
    """
    global LAST_CODE
    if p < 1:
        raise ValueError(f"parallelism must be at least 1, got {p}")
    print(f"Running tests with parallelism = {p}")
    available_gpu_tokens = list(range(p))
    with ThreadPoolExecutor(max_workers=p) as executor:
        # walking through test modules.
        futures = {
            executor.submit(run_test, testmodule, available_gpu_tokens, c): testmodule
            for testmodule in all_testmodules
        }
    # Leaving the `with` block waits for all modules to finish. Exceptions
    # raised inside run_test are stored on the futures and would otherwise be
    # silently lost, so surface them and make the run exit nonzero.
    for future, testmodule in futures.items():
        err = future.exception()
        if err is not None:
            print(f"ERROR - harness failure while running {testmodule}: {err!r}")
            with GPU_LOCK:
                if LAST_CODE == 0:
                    LAST_CODE = 1
|
|
|
|
|
|
def find_num_gpus():
    """Return the number of AMD GPUs reported by ``lspci``.

    Runs a shell pipeline that keeps ``lspci`` lines matching "controller" or
    "accel" and also "AMD/ATI", then counts them with ``wc -l``.
    """
    pipeline = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
    line_count = run_shell_command(pipeline, shell=True)[2]
    return int(line_count)
|
|
|
|
|
|
def main(args):
    """Collect test modules, run them in parallel, build reports, and exit.

    Args:
      args: parsed command-line namespace providing ``parallel`` and
        ``continue_on_fail``.

    Exits the process with LAST_CODE, which stays 0 unless a module returned
    nonzero in fail-fast mode.
    """
    all_testmodules = collect_testmodules()
    run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
    generate_final_report()
    # Equivalent to sys.exit(); the bare exit() helper comes from the site
    # module and is intended for the interactive interpreter only.
    raise SystemExit(LAST_CODE)
|
|
|
|
|
|
if __name__ == "__main__":
    os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p", "--parallel", type=int, help="number of tests to run in parallel"
    )
    parser.add_argument(
        "-c", "--continue_on_fail", action="store_true", help="continue on failure"
    )
    args = parser.parse_args()
    if args.continue_on_fail:
        print("continue on fail is set")
    if args.parallel is None:
        sys_gpu_count = find_num_gpus()
        args.parallel = sys_gpu_count
        print("%d GPUs detected." % sys_gpu_count)
    if args.parallel < 1:
        # A thread pool needs at least one worker; fail early with a clear
        # message instead of a traceback from deep inside the run.
        parser.error("--parallel must be at least 1 (were any AMD GPUs detected?)")
    main(args)
|