mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Add GPU CI (#137)
This commit is contained in: parent 972f95b95d, commit bc06c93d23
.github/workflows/rocm-ci.yml (vendored, new file, 63 lines)
@@ -0,0 +1,63 @@
name: ROCm GPU CI

on:
  # Trigger the workflow on push or pull request,
  # but only for the rocm-main branch
  push:
    branches:
      - rocm-main
  pull_request:
    branches:
      - rocm-main

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  build-jax-in-docker: # strategy and matrix come here
    runs-on: mi-250
    env:
      BASE_IMAGE: "ubuntu:22.04"
      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
      PYTHON_VERSION: "3.10"
      ROCM_VERSION: "6.2.4"
      WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
    steps:
      - name: Clean up old runs
        run: |
          ls
          # Make sure that we own all of the files so that we have permissions to delete them
          docker run -v "./:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
          # Remove any old work directories from this machine
          rm -rf workdir_*
          ls
      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          path: ${{ env.WORKSPACE_DIR }}
      - name: Build JAX
        run: |
          pushd $WORKSPACE_DIR
          python3 build/rocm/ci_build \
            --rocm-version $ROCM_VERSION \
            --base-docker $BASE_IMAGE \
            --python-versions $PYTHON_VERSION \
            --compiler=clang \
            dist_docker \
            --image-tag $TEST_IMAGE
      - name: Archive jax wheels
        uses: actions/upload-artifact@v4
        with:
          name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
          path: ./dist/*.whl
      - name: Run tests
        run: |
          cd $WORKSPACE_DIR
          python3 build/rocm/ci_build test $TEST_IMAGE
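Note on the "Clean up old runs" step: the build runs inside Docker as root, so files left in workdir_* are not owned by the runner user, and a plain rm -rf would fail with permission errors. The step therefore chowns the files back from inside a throwaway container before deleting them. A minimal Python sketch of the same idea, assuming only that the Docker CLI is on PATH (this helper is illustrative, not part of the commit):

    import glob
    import os
    import shutil
    import subprocess

    def clean_old_runs():
        # Reclaim root-owned files by chowning them to our UID from
        # inside a disposable ubuntu container, then delete them.
        uid = os.getuid()
        subprocess.run(
            ["docker", "run", "-v", "./:/jax", "ubuntu", "/bin/bash", "-c",
             "chown -R %d /jax/workdir_* || true" % uid],
            check=True,
        )
        for d in glob.glob("workdir_*"):
            shutil.rmtree(d)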
@@ -5,7 +5,11 @@ ARG ROCM_BUILD_JOB
 ARG ROCM_BUILD_NUM
 
-# Install system GCC and C++ libraries.
-RUN yum install -y gcc-c++.x86_64
+# (charleshofer) This is not ideal, as we should already have GCC and C++ libraries in the
+# manylinux base image. However, adding this does fix an issue where Bazel isn't able
+# to find them.
+RUN --mount=type=cache,target=/var/cache/dnf \
+    dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
 
 RUN --mount=type=cache,target=/var/cache/dnf \
     --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
@@ -20,3 +24,6 @@ RUN --mount=type=cache,target=/var/cache/dnf \
 RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
     mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
     make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
+
+# Stop git from erroring out when we don't own the repo
+RUN git config --global --add safe.directory '*'
@@ -21,11 +21,15 @@
 
 
 import argparse
+import logging
 import os
 import subprocess
 import sys
 
 
+LOG = logging.getLogger("ci_build")
+
+
 def image_by_name(name):
     cmd = ["docker", "images", "-q", "-f", "reference=%s" % name]
     out = subprocess.check_output(cmd)
@@ -33,6 +37,25 @@ def image_by_name(name):
     return image_id
 
 
+def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
+    image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
+    cmd = [
+        "docker",
+        "build",
+        "-f",
+        "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
+        "--build-arg=ROCM_VERSION=%s" % rocm_version,
+        "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
+        "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
+        "--tag=%s" % image_name,
+        ".",
+    ]
+
+    LOG.info("Creating manylinux build image. Running: %s", cmd)
+    _ = subprocess.run(cmd, check=True)
+    return image_name
+
+
 def dist_wheels(
     rocm_version,
     python_versions,
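The new helper centralizes the docker build for the manylinux image so callers no longer duplicate the command list, and it rebuilds unconditionally. A caching variant could reuse the image_by_name() check that already exists in this file, roughly as below (a sketch; ensure_build_image is a hypothetical name, not part of this commit, and it delegates to the create_manylinux_build_image helper added above):

    import subprocess

    def image_by_name(name):
        # "docker images -q -f reference=NAME" prints the image ID,
        # or nothing if no local image matches.
        cmd = ["docker", "images", "-q", "-f", "reference=%s" % name]
        return subprocess.check_output(cmd).decode("utf8").strip()

    def ensure_build_image(rocm_version, rocm_build_job="", rocm_build_num=""):
        name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
        if image_by_name(name):
            return name  # reuse the cached image instead of rebuilding
        return create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num)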
@@ -41,34 +64,13 @@ def dist_wheels(
     rocm_build_num="",
     compiler="gcc",
 ):
+    # We want to make sure the wheels we build are manylinux compliant. We'll
+    # do the build in a container. Build the image for this.
+    image_name = create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num)
 
     if xla_path:
         xla_path = os.path.abspath(xla_path)
 
-    # create manylinux image with requested ROCm installed
-    image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
-
-    # Try removing the Docker image.
-    try:
-        subprocess.run(["docker", "rmi", image], check=True)
-        print(f"Image {image} removed successfully.")
-    except subprocess.CalledProcessError as e:
-        print(f"Failed to remove Docker image {image}: {e}")
-
-    cmd = [
-        "docker",
-        "build",
-        "-f",
-        "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
-        "--build-arg=ROCM_VERSION=%s" % rocm_version,
-        "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
-        "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
-        "--tag=%s" % image,
-        ".",
-    ]
-
-    if not image_by_name(image):
-        _ = subprocess.run(cmd, check=True)
-
     # use image to build JAX/jaxlib wheels
     os.makedirs("wheelhouse", exist_ok=True)
@@ -114,13 +116,14 @@ def dist_wheels(
         [
             "--init",
             "--rm",
-            image,
+            image_name,
             "bash",
             "-c",
             " ".join(bw_cmd),
         ]
     )
 
+    LOG.info("Running: %s", cmd)
     _ = subprocess.run(cmd, check=True)
@@ -141,10 +144,16 @@ def _fetch_jax_metadata(xla_path):
 
     jax_version = subprocess.check_output(cmd, env=env)
 
+    def safe_decode(x):
+        if isinstance(x, str):
+            return x
+        else:
+            return x.decode("utf8")
+
     return {
-        "jax_version": jax_version.decode("utf8").strip(),
-        "jax_commit": jax_commit.decode("utf8").strip(),
-        "xla_commit": xla_commit.decode("utf8").strip(),
+        "jax_version": safe_decode(jax_version).strip(),
+        "jax_commit": safe_decode(jax_commit).strip(),
+        "xla_commit": safe_decode(xla_commit).strip(),
     }
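safe_decode() exists because subprocess.check_output() returns bytes by default but str when text=True (or universal_newlines=True) is passed, so the metadata values can arrive as either type depending on how the command was invoked. A standalone illustration of the same pattern (run it from inside any git checkout):

    import subprocess

    def safe_decode(x):
        # Tolerate both bytes (the check_output default) and str.
        if isinstance(x, str):
            return x
        return x.decode("utf8")

    raw = subprocess.check_output(["git", "rev-parse", "HEAD"])              # bytes
    txt = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True)   # str
    assert safe_decode(raw).strip() == safe_decode(txt).strip()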
@@ -211,10 +220,12 @@ def test(image_name):
     cmd = [
         "docker",
         "run",
-        "-it",
         "--rm",
     ]
 
+    if os.isatty(sys.stdout.fileno()):
+        cmd.append("-it")
+
     # NOTE(mrodden): we need jax source dir for the unit test code only,
     # JAX and jaxlib are already installed from wheels
     mounts = [
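Dropping the hard-coded "-it" matters for CI: docker run -it requests an interactive pseudo-TTY, and in a workflow job stdout is a pipe, so Docker aborts with "the input device is not a TTY". The guard keeps interactive behavior for humans at a terminal while staying batch-safe. The check in isolation:

    import os
    import sys

    cmd = ["docker", "run", "--rm"]

    # Request a TTY only when one is actually attached (local shells);
    # CI pipes fail if "-it" is passed unconditionally.
    if os.isatty(sys.stdout.fileno()):
        cmd.append("-it")

    print(cmd)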
@@ -298,6 +309,7 @@ def parse_args():
 
 
 def main():
+    logging.basicConfig(level=logging.INFO)
     args = parse_args()
 
     if args.action == "dist_wheels":
@@ -25,179 +25,205 @@ GPU_LOCK = threading.Lock()
 LAST_CODE = 0
 base_dir = "./logs"
 
 
 def extract_filename(path):
-  base_name = os.path.basename(path)
-  file_name, _ = os.path.splitext(base_name)
-  return file_name
+    base_name = os.path.basename(path)
+    file_name, _ = os.path.splitext(base_name)
+    return file_name
 
 
 def combine_json_reports():
-  all_json_files = [f for f in os.listdir(base_dir) if f.endswith('_log.json')]
-  combined_data = []
-  for json_file in all_json_files:
-    with open(os.path.join(base_dir, json_file), 'r') as infile:
-      data = json.load(infile)
-      combined_data.append(data)
-  combined_json_file = f"{base_dir}/final_compiled_report.json"
-  with open(combined_json_file, 'w') as outfile:
-    json.dump(combined_data, outfile, indent=4)
+    all_json_files = [f for f in os.listdir(base_dir) if f.endswith("_log.json")]
+    combined_data = []
+    for json_file in all_json_files:
+        with open(os.path.join(base_dir, json_file), "r") as infile:
+            data = json.load(infile)
+            combined_data.append(data)
+    combined_json_file = f"{base_dir}/final_compiled_report.json"
+    with open(combined_json_file, "w") as outfile:
+        json.dump(combined_data, outfile, indent=4)
 
 
 def combine_csv_reports():
-  all_csv_files = [f for f in os.listdir(base_dir) if f.endswith('_log.csv')]
-  combined_csv_file = f"{base_dir}/final_compiled_report.csv"
-  with open(combined_csv_file, mode='w', newline='') as outfile:
-    csv_writer = csv.writer(outfile)
-    for i, csv_file in enumerate(all_csv_files):
-      with open(os.path.join(base_dir, csv_file), mode='r') as infile:
-        csv_reader = csv.reader(infile)
-        if i == 0:
-          # write headers only once
-          csv_writer.writerow(next(csv_reader))
-        for row in csv_reader:
-          csv_writer.writerow(row)
+    all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
+    combined_csv_file = f"{base_dir}/final_compiled_report.csv"
+    with open(combined_csv_file, mode="w", newline="") as outfile:
+        csv_writer = csv.writer(outfile)
+        for i, csv_file in enumerate(all_csv_files):
+            with open(os.path.join(base_dir, csv_file), mode="r") as infile:
+                csv_reader = csv.reader(infile)
+                if i == 0:
+                    # write headers only once
+                    csv_writer.writerow(next(csv_reader))
+                for row in csv_reader:
+                    csv_writer.writerow(row)
 
 
 def generate_final_report(shell=False, env_vars={}):
-  env = os.environ
-  env = {**env, **env_vars}
-  cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
-  result = subprocess.run(cmd,
-                          shell=shell,
-                          capture_output=True,
-                          env=env)
-  if result.returncode != 0:
-    print("FAILED - {}".format(" ".join(cmd)))
-    print(result.stderr.decode())
+    env = os.environ
+    env = {**env, **env_vars}
+    cmd = [
+        "pytest_html_merger",
+        "-i",
+        f"{base_dir}",
+        "-o",
+        f"{base_dir}/final_compiled_report.html",
+    ]
+    result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
+    if result.returncode != 0:
+        print("FAILED - {}".format(" ".join(cmd)))
+        print(result.stderr.decode())
 
-  # Generate json reports.
-  combine_json_reports()
-  # Generate csv reports.
-  combine_csv_reports()
+    # Generate json reports.
+    combine_json_reports()
+    # Generate csv reports.
+    combine_csv_reports()
 
 
 def run_shell_command(cmd, shell=False, env_vars={}):
-  env = os.environ
-  env = {**env, **env_vars}
-  result = subprocess.run(cmd,
-                          shell=shell,
-                          capture_output=True,
-                          env=env)
-  if result.returncode != 0:
-    print("FAILED - {}".format(" ".join(cmd)))
-    print(result.stderr.decode())
+    env = os.environ
+    env = {**env, **env_vars}
+    result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
+    if result.returncode != 0:
+        print("FAILED - {}".format(" ".join(cmd)))
+        print(result.stderr.decode())
 
-  return result.returncode, result.stderr.decode(), result.stdout.decode()
+    return result.returncode, result.stderr.decode(), result.stdout.decode()
 
 
 def parse_test_log(log_file):
-  """Parses the test module log file to extract test modules and functions."""
-  test_files = set()
-  with open(log_file, "r") as f:
-    for line in f:
-      report = json.loads(line)
-      if "nodeid" in report:
-        module = report["nodeid"].split("::")[0]
-        if module and ".py" in module:
-          test_files.add(os.path.abspath(module))
-  return test_files
+    """Parses the test module log file to extract test modules and functions."""
+    test_files = set()
+    with open(log_file, "r") as f:
+        for line in f:
+            report = json.loads(line)
+            if "nodeid" in report:
+                module = report["nodeid"].split("::")[0]
+                if module and ".py" in module:
+                    test_files.add(os.path.abspath(module))
+    return test_files
 
 
 def collect_testmodules():
-  log_file = f"{base_dir}/collect_module_log.jsonl"
-  return_code, stderr, stdout = run_shell_command(
-      ["python3", "-m", "pytest", "--collect-only", "tests", f"--report-log={log_file}"])
-  if return_code != 0:
-    print("Test module discovery failed.")
-    print("STDOUT:", stdout)
-    print("STDERR:", stderr)
-    exit(return_code)
-  print("---------- collected test modules ----------")
-  test_files = parse_test_log(log_file)
-  print("Found %d test modules." % (len(test_files)))
-  print("--------------------------------------------")
-  print("\n".join(test_files))
-  return test_files
+    log_file = f"{base_dir}/collect_module_log.jsonl"
+    return_code, stderr, stdout = run_shell_command(
+        [
+            "python3",
+            "-m",
+            "pytest",
+            "--collect-only",
+            "tests",
+            f"--report-log={log_file}",
+        ]
+    )
+    if return_code != 0:
+        print("Test module discovery failed.")
+        print("STDOUT:", stdout)
+        print("STDERR:", stderr)
+        exit(return_code)
+    print("---------- collected test modules ----------")
+    test_files = parse_test_log(log_file)
+    print("Found %d test modules." % (len(test_files)))
+    print("--------------------------------------------")
+    print("\n".join(test_files))
+    return test_files
 
 
 def run_test(testmodule, gpu_tokens, continue_on_fail):
-  global LAST_CODE
-  with GPU_LOCK:
-    if LAST_CODE != 0:
-      return
-    target_gpu = gpu_tokens.pop()
-  env_vars = {
-      "HIP_VISIBLE_DEVICES": str(target_gpu),
-      "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
-  }
-  testfile = extract_filename(testmodule)
-  if continue_on_fail:
-    cmd = ["python3", "-m", "pytest",
-           "--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
-           f"--csv={base_dir}/{testfile}_log.csv",
-           "--csv-columns", "id,module,name,file,status,duration",
-           f"--html={base_dir}/{testfile}_log.html",
-           "--reruns", "3", "-v", testmodule]
-  else:
-    cmd = ["python3", "-m", "pytest",
-           "--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
-           f"--csv={base_dir}/{testfile}_log.csv",
-           "--csv-columns", "id,module,name,file,status,duration",
-           f"--html={base_dir}/{testfile}_log.html",
-           "--reruns", "3", "-x", "-v", testmodule]
+    global LAST_CODE
+    with GPU_LOCK:
+        if LAST_CODE != 0:
+            return
+        target_gpu = gpu_tokens.pop()
+    env_vars = {
+        "HIP_VISIBLE_DEVICES": str(target_gpu),
+        "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
+    }
+    testfile = extract_filename(testmodule)
+    if continue_on_fail:
+        cmd = [
+            "python3",
+            "-m",
+            "pytest",
+            "--json-report",
+            f"--json-report-file={base_dir}/{testfile}_log.json",
+            f"--csv={base_dir}/{testfile}_log.csv",
+            "--csv-columns",
+            "id,module,name,file,status,duration",
+            f"--html={base_dir}/{testfile}_log.html",
+            "--reruns",
+            "3",
+            "-v",
+            testmodule,
+        ]
+    else:
+        cmd = [
+            "python3",
+            "-m",
+            "pytest",
+            "--json-report",
+            f"--json-report-file={base_dir}/{testfile}_log.json",
+            f"--csv={base_dir}/{testfile}_log.csv",
+            "--csv-columns",
+            "id,module,name,file,status,duration",
+            f"--html={base_dir}/{testfile}_log.html",
+            "--reruns",
+            "3",
+            "-x",
+            "-v",
+            testmodule,
+        ]
 
-  return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
-  with GPU_LOCK:
-    gpu_tokens.append(target_gpu)
-    if LAST_CODE == 0:
-      print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
-      print(stdout)
-      print(stderr)
-      if continue_on_fail == False:
-        LAST_CODE = return_code
+    return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
+    with GPU_LOCK:
+        gpu_tokens.append(target_gpu)
+        if LAST_CODE == 0:
+            print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
+            print(stdout)
+            print(stderr)
+            if continue_on_fail == False:
+                LAST_CODE = return_code
 
 
 def run_parallel(all_testmodules, p, c):
-  print(f"Running tests with parallelism = {p}")
-  available_gpu_tokens = list(range(p))
-  executor = ThreadPoolExecutor(max_workers=p)
-  # walking through test modules.
-  for testmodule in all_testmodules:
-    executor.submit(run_test, testmodule, available_gpu_tokens, c)
-  # waiting for all modules to finish.
-  executor.shutdown(wait=True)
+    print(f"Running tests with parallelism = {p}")
+    available_gpu_tokens = list(range(p))
+    executor = ThreadPoolExecutor(max_workers=p)
+    # walking through test modules.
+    for testmodule in all_testmodules:
+        executor.submit(run_test, testmodule, available_gpu_tokens, c)
+    # waiting for all modules to finish.
+    executor.shutdown(wait=True)
 
 
 def find_num_gpus():
-  cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
-  _, _, stdout = run_shell_command(cmd, shell=True)
-  return int(stdout)
+    cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
+    _, _, stdout = run_shell_command(cmd, shell=True)
+    return int(stdout)
 
 
 def main(args):
-  all_testmodules = collect_testmodules()
-  run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
-  generate_final_report()
-  exit(LAST_CODE)
+    all_testmodules = collect_testmodules()
+    run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
+    generate_final_report()
+    exit(LAST_CODE)
 
 
-if __name__ == '__main__':
-  os.environ['HSA_TOOLS_LIB'] = "libroctracer64.so"
-  parser = argparse.ArgumentParser()
-  parser.add_argument("-p",
-                      "--parallel",
-                      type=int,
-                      help="number of tests to run in parallel")
-  parser.add_argument("-c",
-                      "--continue_on_fail",
-                      action='store_true',
-                      help="continue on failure")
-  args = parser.parse_args()
-  if args.continue_on_fail:
-    print("continue on fail is set")
-  if args.parallel is None:
-    sys_gpu_count = find_num_gpus()
-    args.parallel = sys_gpu_count
-    print("%d GPUs detected." % sys_gpu_count)
+if __name__ == "__main__":
+    os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-p", "--parallel", type=int, help="number of tests to run in parallel"
+    )
+    parser.add_argument(
+        "-c", "--continue_on_fail", action="store_true", help="continue on failure"
+    )
+    args = parser.parse_args()
+    if args.continue_on_fail:
+        print("continue on fail is set")
+    if args.parallel is None:
+        sys_gpu_count = find_num_gpus()
+        args.parallel = sys_gpu_count
+        print("%d GPUs detected." % sys_gpu_count)
 
-  main(args)
+    main(args)
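This runner schedules one pytest module per free GPU: available_gpu_tokens is a plain list used as a semaphore under GPU_LOCK, and each worker pins its subprocess to the claimed device via HIP_VISIBLE_DEVICES. A stripped-down sketch of the token pattern with a dummy workload in place of pytest:

    import threading
    from concurrent.futures import ThreadPoolExecutor

    lock = threading.Lock()
    tokens = list(range(4))  # pretend 4 GPUs were detected

    def run_on_free_gpu(module):
        with lock:
            gpu = tokens.pop()  # claim a device
        try:
            env = {"HIP_VISIBLE_DEVICES": str(gpu)}
            print("%s -> GPU %d (env %s)" % (module, gpu, env))
        finally:
            with lock:
                tokens.append(gpu)  # release for the next module

    with ThreadPoolExecutor(max_workers=len(tokens)) as ex:
        for m in ["tests/a_test.py", "tests/b_test.py"]:
            ex.submit(run_on_free_gpu, m)

The real script additionally records LAST_CODE under the same lock and has run_test return early once any module has failed when fail-fast mode (-x) is active.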
@@ -116,6 +116,7 @@ def build_jaxlib_wheel(
     if compiler == "clang":
         clang_path = find_clang_path()
         if clang_path:
+            LOG.info("Found clang at path: %s", clang_path)
             cmd.append("--clang_path=%s" % clang_path)
         else:
             raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")
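find_clang_path() itself is not shown in this hunk; judging from the error message and the LLVM install prefix (/usr/lib/llvm-18/) used in the Dockerfile above, it plausibly globs the LLVM install directories for a clang binary. A hypothetical reconstruction, for orientation only:

    import glob
    import os

    def find_clang_path():
        # Prefer the newest /usr/lib/llvm-*/bin/clang that is executable.
        for candidate in sorted(glob.glob("/usr/lib/llvm-*/bin/clang"), reverse=True):
            if os.access(candidate, os.X_OK):
                return candidate
        return None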
@@ -315,6 +316,21 @@ def main():
         LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
         shutil.copy(whl, wheelhouse_dir)
 
+    # Delete the 'dist' directory since it causes permissions issues
+    logging.info('Deleting dist, egg-info and cache directory')
+    shutil.rmtree(os.path.join(args.jax_path, "dist"))
+    shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info"))
+    shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__"))
+
+    # Make the wheels deleteable by the runner
+    whl_house = os.path.join(args.jax_path, "wheelhouse")
+    logging.info("Changing permissions for %s" % whl_house)
+    mode = 0o664
+    for item in os.listdir(whl_house):
+        whl_path = os.path.join(whl_house, item)
+        if os.path.isfile(whl_path):
+            os.chmod(whl_path, mode)
+
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
third_party/xla/workspace.bzl (vendored, 6 lines changed)
@@ -21,15 +21,15 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "1a6361a734c5cd10dc93938fc6163a51fd37b82e"
-XLA_SHA256 = "01159fd52f0e402829a3823472a309562817c72d0212f81cd5555f77394c094f"
+XLA_COMMIT = "373f359cbd8d02ee850d98fed92a7bbca4a09c1b"
+XLA_SHA256 = "bccda939edabf6723fcb9e59b833288d66ff93b6f34902c28c521a0b39b52d83"
 
 def repo():
     tf_http_archive(
         name = "xla",
         sha256 = XLA_SHA256,
         strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
-        urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
+        urls = tf_mirror_urls("https://github.com/rocm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
     )
 
 # For development, one often wants to make changes to the TF repository as well