Add GPU CI (#137)

charleshofer 2025-01-07 15:01:50 -06:00 committed by GitHub
parent 972f95b95d
commit bc06c93d23
6 changed files with 297 additions and 173 deletions

.github/workflows/rocm-ci.yml (new file)

@@ -0,0 +1,63 @@
name: ROCm GPU CI
on:
# Trigger the workflow on push or pull request,
# but only for the rocm-main branch
push:
branches:
- rocm-main
pull_request:
branches:
- rocm-main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs:
build-jax-in-docker: # strategy and matrix go here
runs-on: mi-250
env:
BASE_IMAGE: "ubuntu:22.04"
TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
PYTHON_VERSION: "3.10"
ROCM_VERSION: "6.2.4"
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
steps:
- name: Clean up old runs
run: |
ls
# Make sure that we own all of the files so that we have permissions to delete them
docker run -v "./:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
# Remove any old work directories from this machine
rm -rf workdir_*
ls
- name: Print system info
run: |
whoami
printenv
df -h
rocm-smi
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: ${{ env.WORKSPACE_DIR }}
- name: Build JAX
run: |
pushd $WORKSPACE_DIR
python3 build/rocm/ci_build \
--rocm-version $ROCM_VERSION \
--base-docker $BASE_IMAGE \
--python-versions $PYTHON_VERSION \
--compiler=clang \
dist_docker \
--image-tag $TEST_IMAGE
- name: Archive jax wheels
uses: actions/upload-artifact@v4
with:
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
path: ./dist/*.whl
- name: Run tests
run: |
cd $WORKSPACE_DIR
python3 build/rocm/ci_build test $TEST_IMAGE
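
A note on the "Clean up old runs" step above: the build containers run as root, so files they leave behind in workdir_* are root-owned and the unprivileged runner account cannot delete them. The step therefore chowns them from inside a throwaway container before removing them on the host. A minimal Python sketch of the same idea (assuming the Docker CLI is on PATH; the image and mount path are placeholders):

import glob
import os
import shutil
import subprocess

# Reclaim ownership of root-owned build directories via a helper container,
# then delete them on the host.
subprocess.run(
    ["docker", "run", "--rm", "-v", "./:/jax", "ubuntu", "/bin/bash", "-c",
     "chown -R %d /jax/workdir_* || true" % os.getuid()],
    check=False,  # nothing to clean up is not an error
)
for workdir in glob.glob("workdir_*"):
    shutil.rmtree(workdir, ignore_errors=True)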


@@ -5,7 +5,11 @@ ARG ROCM_BUILD_JOB
ARG ROCM_BUILD_NUM
# Install system GCC and C++ libraries.
RUN yum install -y gcc-c++.x86_64
# (charleshofer) This is not ideal, as we should already have GCC and C++ libraries in the
# manylinux base image. However, adding this does fix an issue where Bazel isn't able
# to find them.
RUN --mount=type=cache,target=/var/cache/dnf \
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
@@ -20,3 +24,6 @@ RUN --mount=type=cache,target=/var/cache/dnf \
RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
# Stop git from erroring out when we don't own the repo
RUN git config --global --add safe.directory '*'
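
The final RUN line above works around git's "dubious ownership" check: when the JAX checkout is bind-mounted into the container, its files are owned by the host UID rather than the container user, and git commands would otherwise error out. A hedged sketch of applying the same setting from Python:

import subprocess

# Trust every repository path; acceptable inside a disposable CI container,
# but too permissive for a shared host.
subprocess.run(
    ["git", "config", "--global", "--add", "safe.directory", "*"],
    check=True,
)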


@@ -21,11 +21,15 @@
import argparse
import logging
import os
import subprocess
import sys
LOG = logging.getLogger("ci_build")
def image_by_name(name):
cmd = ["docker", "images", "-q", "-f", "reference=%s" % name]
out = subprocess.check_output(cmd)
@@ -33,6 +37,25 @@ def image_by_name(name):
return image_id
def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
cmd = [
"docker",
"build",
"-f",
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--tag=%s" % image_name,
".",
]
LOG.info("Creating manylinux build image. Running: %s", cmd)
_ = subprocess.run(cmd, check=True)
return image_name
def dist_wheels(
rocm_version,
python_versions,
@@ -41,34 +64,13 @@ def dist_wheels(
rocm_build_num="",
compiler="gcc",
):
# We want to make sure the wheels we build are manylinux compliant. We'll
# do the build in a container. Build the image for this.
image_name = create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num)
if xla_path:
xla_path = os.path.abspath(xla_path)
# create manylinux image with requested ROCm installed
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
# Try removing the Docker image.
try:
subprocess.run(["docker", "rmi", image], check=True)
print(f"Image {image} removed successfully.")
except subprocess.CalledProcessError as e:
print(f"Failed to remove Docker image {image}: {e}")
cmd = [
"docker",
"build",
"-f",
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--tag=%s" % image,
".",
]
if not image_by_name(image):
_ = subprocess.run(cmd, check=True)
# use image to build JAX/jaxlib wheels
os.makedirs("wheelhouse", exist_ok=True)
@@ -114,13 +116,14 @@ def dist_wheels(
[
"--init",
"--rm",
image,
image_name,
"bash",
"-c",
" ".join(bw_cmd),
]
)
LOG.info("Running: %s", cmd)
_ = subprocess.run(cmd, check=True)
@@ -141,10 +144,16 @@ def _fetch_jax_metadata(xla_path):
jax_version = subprocess.check_output(cmd, env=env)
def safe_decode(x):
if isinstance(x, str):
return x
else:
return x.decode("utf8")
return {
"jax_version": jax_version.decode("utf8").strip(),
"jax_commit": jax_commit.decode("utf8").strip(),
"xla_commit": xla_commit.decode("utf8").strip(),
"jax_version": safe_decode(jax_version).strip(),
"jax_commit": safe_decode(jax_commit).strip(),
"xla_commit": safe_decode(xla_commit).strip(),
}
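
The safe_decode helper added above guards against subprocess output arriving as either bytes (the default) or str (when text=True is in effect somewhere upstream). A minimal sketch of the pattern in isolation:

import subprocess

def safe_decode(x):
    # check_output returns bytes by default, but str if text=True was used
    return x if isinstance(x, str) else x.decode("utf8")

commit = safe_decode(subprocess.check_output(["git", "rev-parse", "HEAD"]))
print(commit.strip())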
@@ -211,10 +220,12 @@ def test(image_name):
cmd = [
"docker",
"run",
"-it",
"--rm",
]
if os.isatty(sys.stdout.fileno()):
cmd.append("-it")
# NOTE(mrodden): we need jax source dir for the unit test code only,
# JAX and jaxlib are already installed from wheels
mounts = [
@@ -298,6 +309,7 @@ def parse_args():
def main():
logging.basicConfig(level=logging.INFO)
args = parse_args()
if args.action == "dist_wheels":
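
The isatty guard added in test() matters because `docker run -it` aborts with "the input device is not a TTY" on a CI runner, while interactive users still get a usable terminal. A standalone sketch of the guard (image and command are placeholders):

import os
import subprocess
import sys

cmd = ["docker", "run", "--rm"]
if os.isatty(sys.stdout.fileno()):
    cmd.append("-it")  # request a TTY only when we actually have one
cmd += ["ubuntu:22.04", "echo", "tty-aware docker run"]
subprocess.run(cmd, check=True)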


@@ -25,179 +25,205 @@ GPU_LOCK = threading.Lock()
LAST_CODE = 0
base_dir = "./logs"
def extract_filename(path):
base_name = os.path.basename(path)
file_name, _ = os.path.splitext(base_name)
return file_name
base_name = os.path.basename(path)
file_name, _ = os.path.splitext(base_name)
return file_name
def combine_json_reports():
all_json_files = [f for f in os.listdir(base_dir) if f.endswith('_log.json')]
combined_data = []
for json_file in all_json_files:
with open(os.path.join(base_dir, json_file), 'r') as infile:
data = json.load(infile)
combined_data.append(data)
combined_json_file = f"{base_dir}/final_compiled_report.json"
with open(combined_json_file, 'w') as outfile:
json.dump(combined_data, outfile, indent=4)
all_json_files = [f for f in os.listdir(base_dir) if f.endswith("_log.json")]
combined_data = []
for json_file in all_json_files:
with open(os.path.join(base_dir, json_file), "r") as infile:
data = json.load(infile)
combined_data.append(data)
combined_json_file = f"{base_dir}/final_compiled_report.json"
with open(combined_json_file, "w") as outfile:
json.dump(combined_data, outfile, indent=4)
def combine_csv_reports():
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith('_log.csv')]
combined_csv_file = f"{base_dir}/final_compiled_report.csv"
with open(combined_csv_file, mode='w', newline='') as outfile:
csv_writer = csv.writer(outfile)
for i, csv_file in enumerate(all_csv_files):
with open(os.path.join(base_dir, csv_file), mode='r') as infile:
csv_reader = csv.reader(infile)
if i == 0:
# write headers only once
csv_writer.writerow(next(csv_reader))
for row in csv_reader:
csv_writer.writerow(row)
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
combined_csv_file = f"{base_dir}/final_compiled_report.csv"
with open(combined_csv_file, mode="w", newline="") as outfile:
csv_writer = csv.writer(outfile)
for i, csv_file in enumerate(all_csv_files):
with open(os.path.join(base_dir, csv_file), mode="r") as infile:
csv_reader = csv.reader(infile)
if i == 0:
# write headers only once
csv_writer.writerow(next(csv_reader))
for row in csv_reader:
csv_writer.writerow(row)
def generate_final_report(shell=False, env_vars={}):
env = os.environ
env = {**env, **env_vars}
cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
result = subprocess.run(cmd,
shell=shell,
capture_output=True,
env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
env = os.environ
env = {**env, **env_vars}
cmd = [
"pytest_html_merger",
"-i",
f"{base_dir}",
"-o",
f"{base_dir}/final_compiled_report.html",
]
result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
# Generate json reports.
combine_json_reports()
# Generate csv reports.
combine_csv_reports()
# Generate json reports.
combine_json_reports()
# Generate csv reports.
combine_csv_reports()
def run_shell_command(cmd, shell=False, env_vars={}):
env = os.environ
env = {**env, **env_vars}
result = subprocess.run(cmd,
shell=shell,
capture_output=True,
env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
env = os.environ
env = {**env, **env_vars}
result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
return result.returncode, result.stderr.decode(), result.stdout.decode()
return result.returncode, result.stderr.decode(), result.stdout.decode()
def parse_test_log(log_file):
"""Parses the test module log file to extract test modules and functions."""
test_files = set()
with open(log_file, "r") as f:
for line in f:
report = json.loads(line)
if "nodeid" in report:
module = report["nodeid"].split("::")[0]
if module and ".py" in module:
test_files.add(os.path.abspath(module))
return test_files
"""Parses the test module log file to extract test modules and functions."""
test_files = set()
with open(log_file, "r") as f:
for line in f:
report = json.loads(line)
if "nodeid" in report:
module = report["nodeid"].split("::")[0]
if module and ".py" in module:
test_files.add(os.path.abspath(module))
return test_files
def collect_testmodules():
log_file = f"{base_dir}/collect_module_log.jsonl"
return_code, stderr, stdout = run_shell_command(
["python3", "-m", "pytest", "--collect-only", "tests", f"--report-log={log_file}"])
if return_code != 0:
print("Test module discovery failed.")
print("STDOUT:", stdout)
print("STDERR:", stderr)
exit(return_code)
print("---------- collected test modules ----------")
test_files = parse_test_log(log_file)
print("Found %d test modules." % (len(test_files)))
print("--------------------------------------------")
print("\n".join(test_files))
return test_files
log_file = f"{base_dir}/collect_module_log.jsonl"
return_code, stderr, stdout = run_shell_command(
[
"python3",
"-m",
"pytest",
"--collect-only",
"tests",
f"--report-log={log_file}",
]
)
if return_code != 0:
print("Test module discovery failed.")
print("STDOUT:", stdout)
print("STDERR:", stderr)
exit(return_code)
print("---------- collected test modules ----------")
test_files = parse_test_log(log_file)
print("Found %d test modules." % (len(test_files)))
print("--------------------------------------------")
print("\n".join(test_files))
return test_files
def run_test(testmodule, gpu_tokens, continue_on_fail):
global LAST_CODE
with GPU_LOCK:
if LAST_CODE != 0:
return
target_gpu = gpu_tokens.pop()
env_vars = {
"HIP_VISIBLE_DEVICES": str(target_gpu),
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
}
testfile = extract_filename(testmodule)
if continue_on_fail:
cmd = ["python3", "-m", "pytest",
"--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns", "id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html",
"--reruns", "3", "-v", testmodule]
else:
cmd = ["python3", "-m", "pytest",
"--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns", "id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html",
"--reruns", "3", "-x", "-v", testmodule]
global LAST_CODE
with GPU_LOCK:
if LAST_CODE != 0:
return
target_gpu = gpu_tokens.pop()
env_vars = {
"HIP_VISIBLE_DEVICES": str(target_gpu),
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
}
testfile = extract_filename(testmodule)
if continue_on_fail:
cmd = [
"python3",
"-m",
"pytest",
"--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns",
"id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html",
"--reruns",
"3",
"-v",
testmodule,
]
else:
cmd = [
"python3",
"-m",
"pytest",
"--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns",
"id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html",
"--reruns",
"3",
"-x",
"-v",
testmodule,
]
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK:
gpu_tokens.append(target_gpu)
if LAST_CODE == 0:
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
print(stdout)
print(stderr)
if continue_on_fail == False:
LAST_CODE = return_code
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK:
gpu_tokens.append(target_gpu)
if LAST_CODE == 0:
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
print(stdout)
print(stderr)
if continue_on_fail == False:
LAST_CODE = return_code
def run_parallel(all_testmodules, p, c):
print(f"Running tests with parallelism = {p}")
available_gpu_tokens = list(range(p))
executor = ThreadPoolExecutor(max_workers=p)
# walking through test modules.
for testmodule in all_testmodules:
executor.submit(run_test, testmodule, available_gpu_tokens, c)
# waiting for all modules to finish.
executor.shutdown(wait=True)
print(f"Running tests with parallelism = {p}")
available_gpu_tokens = list(range(p))
executor = ThreadPoolExecutor(max_workers=p)
# walking through test modules.
for testmodule in all_testmodules:
executor.submit(run_test, testmodule, available_gpu_tokens, c)
# waiting for all modules to finish.
executor.shutdown(wait=True)
def find_num_gpus():
cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
_, _, stdout = run_shell_command(cmd, shell=True)
return int(stdout)
cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
_, _, stdout = run_shell_command(cmd, shell=True)
return int(stdout)
def main(args):
all_testmodules = collect_testmodules()
run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
generate_final_report()
exit(LAST_CODE)
all_testmodules = collect_testmodules()
run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
generate_final_report()
exit(LAST_CODE)
if __name__ == '__main__':
os.environ['HSA_TOOLS_LIB'] = "libroctracer64.so"
parser = argparse.ArgumentParser()
parser.add_argument("-p",
"--parallel",
type=int,
help="number of tests to run in parallel")
parser.add_argument("-c",
"--continue_on_fail",
action='store_true',
help="continue on failure")
args = parser.parse_args()
if args.continue_on_fail:
print("continue on fail is set")
if args.parallel is None:
sys_gpu_count = find_num_gpus()
args.parallel = sys_gpu_count
print("%d GPUs detected." % sys_gpu_count)
if __name__ == "__main__":
os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--parallel", type=int, help="number of tests to run in parallel"
)
parser.add_argument(
"-c", "--continue_on_fail", action="store_true", help="continue on failure"
)
args = parser.parse_args()
if args.continue_on_fail:
print("continue on fail is set")
if args.parallel is None:
sys_gpu_count = find_num_gpus()
args.parallel = sys_gpu_count
print("%d GPUs detected." % sys_gpu_count)
main(args)
main(args)
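
run_test and run_parallel above implement a simple GPU pool: a shared list of device indices guarded by GPU_LOCK, where each worker pops a token, pins its pytest run to that device via HIP_VISIBLE_DEVICES, and returns the token when done. A reduced sketch of the pattern (GPU count and test modules are placeholders):

import threading
from concurrent.futures import ThreadPoolExecutor

lock = threading.Lock()
gpu_tokens = list(range(4))  # placeholder: machine with 4 GPUs

def run_on_free_gpu(module):
    with lock:
        gpu = gpu_tokens.pop()  # safe: workers never outnumber tokens
    try:
        print(f"HIP_VISIBLE_DEVICES={gpu} python3 -m pytest {module}")
    finally:
        with lock:
            gpu_tokens.append(gpu)

with ThreadPoolExecutor(max_workers=4) as executor:
    for module in ["tests/a_test.py", "tests/b_test.py", "tests/c_test.py"]:
        executor.submit(run_on_free_gpu, module)

Sizing max_workers to the token count is what makes the bare pop() safe; the script above relies on the same invariant by defaulting --parallel to the detected GPU count.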


@@ -116,6 +116,7 @@ def build_jaxlib_wheel(
if compiler == "clang":
clang_path = find_clang_path()
if clang_path:
LOG.info("Found clang at path: %s", clang_path)
cmd.append("--clang_path=%s" % clang_path)
else:
raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")
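
find_clang_path is not shown in this hunk; from the error message it evidently searches the /usr/lib/llvm-* prefixes that the Dockerfile installs LLVM 18 into. A hypothetical reimplementation, labeled as such:

import glob
import os

def find_clang_path():
    # Assumption: clang lives under an /usr/lib/llvm-<N>/bin prefix, as set
    # up by the LLVM build earlier in the Dockerfile.
    for candidate in sorted(glob.glob("/usr/lib/llvm-*/bin/clang"), reverse=True):
        if os.access(candidate, os.X_OK):
            return candidate
    return None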
@@ -315,6 +316,21 @@ def main():
LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
shutil.copy(whl, wheelhouse_dir)
# Delete the 'dist', egg-info, and __pycache__ directories since they cause permissions issues
logging.info("Deleting dist, egg-info and cache directory")
shutil.rmtree(os.path.join(args.jax_path, "dist"))
shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info"))
shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__"))
# Make the wheels deletable by the runner
whl_house = os.path.join(args.jax_path, "wheelhouse")
logging.info("Changing permissions for %s" % whl_house)
mode = 0o664
for item in os.listdir(whl_house):
whl_path = os.path.join(whl_house, item)
if os.path.isfile(whl_path):
os.chmod(whl_path, mode)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)


@@ -21,15 +21,15 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "1a6361a734c5cd10dc93938fc6163a51fd37b82e"
XLA_SHA256 = "01159fd52f0e402829a3823472a309562817c72d0212f81cd5555f77394c094f"
XLA_COMMIT = "373f359cbd8d02ee850d98fed92a7bbca4a09c1b"
XLA_SHA256 = "bccda939edabf6723fcb9e59b833288d66ff93b6f34902c28c521a0b39b52d83"
def repo():
tf_http_archive(
name = "xla",
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
urls = tf_mirror_urls("https://github.com/rocm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
)
# For development, one often wants to make changes to the TF repository as well
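
The pin above now tracks the ROCm fork of XLA rather than openxla/xla. When bumping XLA_COMMIT, the checksum must be recomputed as the comment at the top of the file describes; an equivalent Python one-off (the commit value is the one pinned above):

import hashlib
import urllib.request

commit = "373f359cbd8d02ee850d98fed92a7bbca4a09c1b"
url = "https://github.com/rocm/xla/archive/%s.tar.gz" % commit
with urllib.request.urlopen(url) as resp:
    print(hashlib.sha256(resp.read()).hexdigest())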