mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge branch 'rocm-main' into ci-upstream-sync-97_1
This commit is contained in:
commit
63e6442bdf
37
.github/workflows/ci-build.yaml
vendored
37
.github/workflows/ci-build.yaml
vendored
@ -1,4 +1,4 @@
|
||||
name: CI
|
||||
name: ROCm CPU CI
|
||||
|
||||
# We test all supported Python versions as follows:
|
||||
# - 3.10 : Documentation build
|
||||
@ -11,10 +11,10 @@ on:
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- rocm-main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- rocm-main
|
||||
|
||||
permissions:
|
||||
contents: read # to fetch code
|
||||
@ -42,12 +42,8 @@ jobs:
|
||||
- run: pre-commit run --show-diff-on-failure --color=always --all-files
|
||||
|
||||
build:
|
||||
# Don't execute in fork due to runner type
|
||||
if: github.repository == 'jax-ml/jax'
|
||||
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
|
||||
runs-on: linux-x86-n2-32
|
||||
container:
|
||||
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
|
||||
runs-on: ROCM-Ubuntu
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
@ -65,10 +61,6 @@ jobs:
|
||||
num_generated_cases: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Image Setup
|
||||
run: |
|
||||
apt update
|
||||
apt install -y libssl-dev
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
@ -109,7 +101,7 @@ jobs:
|
||||
|
||||
documentation:
|
||||
name: Documentation - test code snippets
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ROCM-Ubuntu
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
@ -146,19 +138,13 @@ jobs:
|
||||
|
||||
documentation_render:
|
||||
name: Documentation - render documentation
|
||||
runs-on: linux-x86-n2-16
|
||||
container:
|
||||
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.10']
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Image Setup
|
||||
run: |
|
||||
apt update
|
||||
apt install -y libssl-dev libsqlite3-dev
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
@ -229,9 +215,7 @@ jobs:
|
||||
|
||||
ffi:
|
||||
name: FFI example
|
||||
runs-on: linux-x86-g2-16-l4-1gpu
|
||||
container:
|
||||
image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
|
||||
runs-on: ROCM-Ubuntu
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -250,7 +234,7 @@ jobs:
|
||||
path: ${{ steps.pip-cache.outputs.dir }}
|
||||
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
|
||||
- name: Install JAX
|
||||
run: pip install .[cuda12]
|
||||
run: pip install .
|
||||
- name: Build and install example project
|
||||
run: python -m pip install -v ./examples/ffi[test]
|
||||
env:
|
||||
@ -259,10 +243,11 @@ jobs:
|
||||
# a different toolchain. GCC is the default compiler on the
|
||||
# 'ubuntu-latest' runner, but we still set this explicitly just to be
|
||||
# clear.
|
||||
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
|
||||
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
|
||||
- name: Run CPU tests
|
||||
run: python -m pytest examples/ffi/tests
|
||||
env:
|
||||
JAX_PLATFORM_NAME: cpu
|
||||
- name: Run GPU tests
|
||||
run: python -m pytest examples/ffi/tests
|
||||
|
||||
|
63
.github/workflows/rocm-ci.yml
vendored
Normal file
63
.github/workflows/rocm-ci.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: ROCm GPU CI
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the rocm-main branch
|
||||
push:
|
||||
branches:
|
||||
- rocm-main
|
||||
pull_request:
|
||||
branches:
|
||||
- rocm-main
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-jax-in-docker: # strategy and matrix come here
|
||||
runs-on: mi-250
|
||||
env:
|
||||
BASE_IMAGE: "ubuntu:22.04"
|
||||
TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
|
||||
PYTHON_VERSION: "3.10"
|
||||
ROCM_VERSION: "6.2.4"
|
||||
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
|
||||
steps:
|
||||
- name: Clean up old runs
|
||||
run: |
|
||||
ls
|
||||
# Make sure that we own all of the files so that we have permissions to delete them
|
||||
docker run -v "./:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
|
||||
# Remove any old work directories from this machine
|
||||
rm -rf workdir_*
|
||||
ls
|
||||
- name: Print system info
|
||||
run: |
|
||||
whoami
|
||||
printenv
|
||||
df -h
|
||||
rocm-smi
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
path: ${{ env.WORKSPACE_DIR }}
|
||||
- name: Build JAX
|
||||
run: |
|
||||
pushd $WORKSPACE_DIR
|
||||
python3 build/rocm/ci_build \
|
||||
--rocm-version $ROCM_VERSION \
|
||||
--base-docker $BASE_IMAGE \
|
||||
--python-versions $PYTHON_VERSION \
|
||||
--compiler=clang \
|
||||
dist_docker \
|
||||
--image-tag $TEST_IMAGE
|
||||
- name: Archive jax wheels
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
|
||||
path: ./dist/*.whl
|
||||
- name: Run tests
|
||||
run: |
|
||||
cd $WORKSPACE_DIR
|
||||
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"
|
||||
|
52
.github/workflows/rocm-nightly-upstream-sync.yml
vendored
Normal file
52
.github/workflows/rocm-nightly-upstream-sync.yml
vendored
Normal file
@ -0,0 +1,52 @@
|
||||
# Pulls the latest changes from upstream into main and opens a PR to merge
|
||||
# them into rocm-main branch.
|
||||
|
||||
name: ROCm Nightly Upstream Sync
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 6 * * 1-5'
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
env:
|
||||
SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
|
||||
jobs:
|
||||
sync-main:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: |
|
||||
gh auth status
|
||||
gh repo sync rocm/jax -b main
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
create-sync-branch:
|
||||
needs: sync-main
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Create branch
|
||||
run: |
|
||||
git fetch
|
||||
git checkout origin/main
|
||||
git checkout -b $SYNC_BRANCH_NAME
|
||||
# Try and merge rocm-main into this new branch so that we don't run upstream's CI code
|
||||
git config --global user.email "github-actions@github.com"
|
||||
git config --global user.name "GitHub Actions"
|
||||
git merge origin/rocm-main || true
|
||||
# If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
|
||||
git merge --abort || true
|
||||
git push origin HEAD
|
||||
open-sync-pr:
|
||||
needs: create-sync-branch
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: |
|
||||
gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
|
||||
gh pr merge --repo $GITHUB_REPOSITORY --merge --auto $SYNC_BRANCH_NAME
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
41
.github/workflows/rocm-open-upstream-pr.yml
vendored
Normal file
41
.github/workflows/rocm-open-upstream-pr.yml
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
name: ROCm Open Upstream PR
|
||||
on:
|
||||
pull_request:
|
||||
types: [ labeled ]
|
||||
branches: [ rocm-main ]
|
||||
jobs:
|
||||
open-upstream:
|
||||
if: ${{ github.event.label.name == 'open-upstream' }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Rebase code to main
|
||||
run: |
|
||||
git config --global user.email "github-actions@github.com"
|
||||
git config --global user.name "Github Actions"
|
||||
git fetch
|
||||
git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }}
|
||||
git rebase --onto origin/main origin/rocm-main
|
||||
# Force push here so that we don't run into conflicts with the origin branch
|
||||
git push origin HEAD --force
|
||||
- name: Leave link to create PR
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# Bash is not friendly with newline characters, so make our own
|
||||
NL=$'\n'
|
||||
# Encode the PR title and body for passing as URL get parameters
|
||||
TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri')
|
||||
BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri')
|
||||
# Create a link to the that will open up a new PR form to upstream and autofill the fields
|
||||
CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
|
||||
# Add a comment with the link to the PR
|
||||
COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
|
||||
gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY"
|
||||
|
2
.github/workflows/upstream-nightly.yml
vendored
2
.github/workflows/upstream-nightly.yml
vendored
@ -22,7 +22,7 @@ on:
|
||||
|
||||
jobs:
|
||||
upstream-dev:
|
||||
runs-on: ubuntu-20.04-16core
|
||||
runs-on: ROCM-Ubuntu
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write # for failed-build-issue
|
||||
|
@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
|
||||
### Simplified Build Script
|
||||
|
||||
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.
|
||||
|
||||
|
@ -5,7 +5,11 @@ ARG ROCM_BUILD_JOB
|
||||
ARG ROCM_BUILD_NUM
|
||||
|
||||
# Install system GCC and C++ libraries.
|
||||
RUN yum install -y gcc-c++.x86_64
|
||||
# (charleshofer) This is not ideal, as we should already have GCC and C++ libraries in the
|
||||
# manylinux base image. However, adding this does fix an issue where Bazel isn't able
|
||||
# to find them.
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
|
||||
@ -20,3 +24,6 @@ RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
|
||||
mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
|
||||
make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
|
||||
|
||||
# Stop git from erroring out when we don't own the repo
|
||||
RUN git config --global --add safe.directory '*'
|
||||
|
@ -21,11 +21,15 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
LOG = logging.getLogger("ci_build")
|
||||
|
||||
|
||||
def image_by_name(name):
|
||||
cmd = ["docker", "images", "-q", "-f", "reference=%s" % name]
|
||||
out = subprocess.check_output(cmd)
|
||||
@ -33,6 +37,25 @@ def image_by_name(name):
|
||||
return image_id
|
||||
|
||||
|
||||
def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
|
||||
image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
|
||||
cmd = [
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
|
||||
"--build-arg=ROCM_VERSION=%s" % rocm_version,
|
||||
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
|
||||
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
|
||||
"--tag=%s" % image_name,
|
||||
".",
|
||||
]
|
||||
|
||||
LOG.info("Creating manylinux build image. Running: %s", cmd)
|
||||
_ = subprocess.run(cmd, check=True)
|
||||
return image_name
|
||||
|
||||
|
||||
def dist_wheels(
|
||||
rocm_version,
|
||||
python_versions,
|
||||
@ -41,34 +64,13 @@ def dist_wheels(
|
||||
rocm_build_num="",
|
||||
compiler="gcc",
|
||||
):
|
||||
# We want to make sure the wheels we build are manylinux compliant. We'll
|
||||
# do the build in a container. Build the image for this.
|
||||
image_name = create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num)
|
||||
|
||||
if xla_path:
|
||||
xla_path = os.path.abspath(xla_path)
|
||||
|
||||
# create manylinux image with requested ROCm installed
|
||||
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
|
||||
|
||||
# Try removing the Docker image.
|
||||
try:
|
||||
subprocess.run(["docker", "rmi", image], check=True)
|
||||
print(f"Image {image} removed successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Failed to remove Docker image {image}: {e}")
|
||||
|
||||
cmd = [
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
|
||||
"--build-arg=ROCM_VERSION=%s" % rocm_version,
|
||||
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
|
||||
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
|
||||
"--tag=%s" % image,
|
||||
".",
|
||||
]
|
||||
|
||||
if not image_by_name(image):
|
||||
_ = subprocess.run(cmd, check=True)
|
||||
|
||||
# use image to build JAX/jaxlib wheels
|
||||
os.makedirs("wheelhouse", exist_ok=True)
|
||||
|
||||
@ -114,13 +116,14 @@ def dist_wheels(
|
||||
[
|
||||
"--init",
|
||||
"--rm",
|
||||
image,
|
||||
image_name,
|
||||
"bash",
|
||||
"-c",
|
||||
" ".join(bw_cmd),
|
||||
]
|
||||
)
|
||||
|
||||
LOG.info("Running: %s", cmd)
|
||||
_ = subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
@ -141,10 +144,16 @@ def _fetch_jax_metadata(xla_path):
|
||||
|
||||
jax_version = subprocess.check_output(cmd, env=env)
|
||||
|
||||
def safe_decode(x):
|
||||
if isinstance(x, str):
|
||||
return x
|
||||
else:
|
||||
return x.decode("utf8")
|
||||
|
||||
return {
|
||||
"jax_version": jax_version.decode("utf8").strip(),
|
||||
"jax_commit": jax_commit.decode("utf8").strip(),
|
||||
"xla_commit": xla_commit.decode("utf8").strip(),
|
||||
"jax_version": safe_decode(jax_version).strip(),
|
||||
"jax_commit": safe_decode(jax_commit).strip(),
|
||||
"xla_commit": safe_decode(xla_commit).strip(),
|
||||
}
|
||||
|
||||
|
||||
@ -193,7 +202,7 @@ def dist_docker(
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
def test(image_name):
|
||||
def test(image_name, test_cmd):
|
||||
"""Run unit tests like CI would inside a JAX image."""
|
||||
|
||||
gpu_args = [
|
||||
@ -211,10 +220,12 @@ def test(image_name):
|
||||
cmd = [
|
||||
"docker",
|
||||
"run",
|
||||
"-it",
|
||||
"--rm",
|
||||
]
|
||||
|
||||
if os.isatty(sys.stdout.fileno()):
|
||||
cmd.append("-it")
|
||||
|
||||
# NOTE(mrodden): we need jax source dir for the unit test code only,
|
||||
# JAX and jaxlib are already installed from wheels
|
||||
mounts = [
|
||||
@ -225,7 +236,7 @@ def test(image_name):
|
||||
cmd.extend(mounts)
|
||||
cmd.extend(gpu_args)
|
||||
|
||||
container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh"
|
||||
container_cmd = "cd /jax && " + test_cmd
|
||||
cmd.append(image_name)
|
||||
cmd.extend(
|
||||
[
|
||||
@ -288,6 +299,7 @@ def parse_args():
|
||||
|
||||
testp = subp.add_parser("test")
|
||||
testp.add_argument("image_name")
|
||||
testp.add_argument("--test-cmd", default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh")
|
||||
|
||||
ddp = subp.add_parser("dist_docker")
|
||||
ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
|
||||
@ -298,6 +310,7 @@ def parse_args():
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
args = parse_args()
|
||||
|
||||
if args.action == "dist_wheels":
|
||||
@ -310,7 +323,7 @@ def main():
|
||||
)
|
||||
|
||||
elif args.action == "test":
|
||||
test(args.image_name)
|
||||
test(args.image_name, args.test_cmd)
|
||||
|
||||
elif args.action == "dist_docker":
|
||||
dist_wheels(
|
||||
|
@ -25,179 +25,205 @@ GPU_LOCK = threading.Lock()
|
||||
LAST_CODE = 0
|
||||
base_dir = "./logs"
|
||||
|
||||
|
||||
def extract_filename(path):
|
||||
base_name = os.path.basename(path)
|
||||
file_name, _ = os.path.splitext(base_name)
|
||||
return file_name
|
||||
base_name = os.path.basename(path)
|
||||
file_name, _ = os.path.splitext(base_name)
|
||||
return file_name
|
||||
|
||||
|
||||
def combine_json_reports():
|
||||
all_json_files = [f for f in os.listdir(base_dir) if f.endswith('_log.json')]
|
||||
combined_data = []
|
||||
for json_file in all_json_files:
|
||||
with open(os.path.join(base_dir, json_file), 'r') as infile:
|
||||
data = json.load(infile)
|
||||
combined_data.append(data)
|
||||
combined_json_file = f"{base_dir}/final_compiled_report.json"
|
||||
with open(combined_json_file, 'w') as outfile:
|
||||
json.dump(combined_data, outfile, indent=4)
|
||||
all_json_files = [f for f in os.listdir(base_dir) if f.endswith("_log.json")]
|
||||
combined_data = []
|
||||
for json_file in all_json_files:
|
||||
with open(os.path.join(base_dir, json_file), "r") as infile:
|
||||
data = json.load(infile)
|
||||
combined_data.append(data)
|
||||
combined_json_file = f"{base_dir}/final_compiled_report.json"
|
||||
with open(combined_json_file, "w") as outfile:
|
||||
json.dump(combined_data, outfile, indent=4)
|
||||
|
||||
|
||||
def combine_csv_reports():
|
||||
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith('_log.csv')]
|
||||
combined_csv_file = f"{base_dir}/final_compiled_report.csv"
|
||||
with open(combined_csv_file, mode='w', newline='') as outfile:
|
||||
csv_writer = csv.writer(outfile)
|
||||
for i, csv_file in enumerate(all_csv_files):
|
||||
with open(os.path.join(base_dir, csv_file), mode='r') as infile:
|
||||
csv_reader = csv.reader(infile)
|
||||
if i == 0:
|
||||
# write headers only once
|
||||
csv_writer.writerow(next(csv_reader))
|
||||
for row in csv_reader:
|
||||
csv_writer.writerow(row)
|
||||
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
|
||||
combined_csv_file = f"{base_dir}/final_compiled_report.csv"
|
||||
with open(combined_csv_file, mode="w", newline="") as outfile:
|
||||
csv_writer = csv.writer(outfile)
|
||||
for i, csv_file in enumerate(all_csv_files):
|
||||
with open(os.path.join(base_dir, csv_file), mode="r") as infile:
|
||||
csv_reader = csv.reader(infile)
|
||||
if i == 0:
|
||||
# write headers only once
|
||||
csv_writer.writerow(next(csv_reader))
|
||||
for row in csv_reader:
|
||||
csv_writer.writerow(row)
|
||||
|
||||
|
||||
def generate_final_report(shell=False, env_vars={}):
|
||||
env = os.environ
|
||||
env = {**env, **env_vars}
|
||||
cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
|
||||
result = subprocess.run(cmd,
|
||||
shell=shell,
|
||||
capture_output=True,
|
||||
env=env)
|
||||
if result.returncode != 0:
|
||||
print("FAILED - {}".format(" ".join(cmd)))
|
||||
print(result.stderr.decode())
|
||||
env = os.environ
|
||||
env = {**env, **env_vars}
|
||||
cmd = [
|
||||
"pytest_html_merger",
|
||||
"-i",
|
||||
f"{base_dir}",
|
||||
"-o",
|
||||
f"{base_dir}/final_compiled_report.html",
|
||||
]
|
||||
result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
|
||||
if result.returncode != 0:
|
||||
print("FAILED - {}".format(" ".join(cmd)))
|
||||
print(result.stderr.decode())
|
||||
|
||||
# Generate json reports.
|
||||
combine_json_reports()
|
||||
# Generate csv reports.
|
||||
combine_csv_reports()
|
||||
# Generate json reports.
|
||||
combine_json_reports()
|
||||
# Generate csv reports.
|
||||
combine_csv_reports()
|
||||
|
||||
|
||||
def run_shell_command(cmd, shell=False, env_vars={}):
|
||||
env = os.environ
|
||||
env = {**env, **env_vars}
|
||||
result = subprocess.run(cmd,
|
||||
shell=shell,
|
||||
capture_output=True,
|
||||
env=env)
|
||||
if result.returncode != 0:
|
||||
print("FAILED - {}".format(" ".join(cmd)))
|
||||
print(result.stderr.decode())
|
||||
env = os.environ
|
||||
env = {**env, **env_vars}
|
||||
result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
|
||||
if result.returncode != 0:
|
||||
print("FAILED - {}".format(" ".join(cmd)))
|
||||
print(result.stderr.decode())
|
||||
|
||||
return result.returncode, result.stderr.decode(), result.stdout.decode()
|
||||
return result.returncode, result.stderr.decode(), result.stdout.decode()
|
||||
|
||||
|
||||
def parse_test_log(log_file):
|
||||
"""Parses the test module log file to extract test modules and functions."""
|
||||
test_files = set()
|
||||
with open(log_file, "r") as f:
|
||||
for line in f:
|
||||
report = json.loads(line)
|
||||
if "nodeid" in report:
|
||||
module = report["nodeid"].split("::")[0]
|
||||
if module and ".py" in module:
|
||||
test_files.add(os.path.abspath(module))
|
||||
return test_files
|
||||
"""Parses the test module log file to extract test modules and functions."""
|
||||
test_files = set()
|
||||
with open(log_file, "r") as f:
|
||||
for line in f:
|
||||
report = json.loads(line)
|
||||
if "nodeid" in report:
|
||||
module = report["nodeid"].split("::")[0]
|
||||
if module and ".py" in module:
|
||||
test_files.add(os.path.abspath(module))
|
||||
return test_files
|
||||
|
||||
|
||||
def collect_testmodules():
|
||||
log_file = f"{base_dir}/collect_module_log.jsonl"
|
||||
return_code, stderr, stdout = run_shell_command(
|
||||
["python3", "-m", "pytest", "--collect-only", "tests", f"--report-log={log_file}"])
|
||||
if return_code != 0:
|
||||
print("Test module discovery failed.")
|
||||
print("STDOUT:", stdout)
|
||||
print("STDERR:", stderr)
|
||||
exit(return_code)
|
||||
print("---------- collected test modules ----------")
|
||||
test_files = parse_test_log(log_file)
|
||||
print("Found %d test modules." % (len(test_files)))
|
||||
print("--------------------------------------------")
|
||||
print("\n".join(test_files))
|
||||
return test_files
|
||||
log_file = f"{base_dir}/collect_module_log.jsonl"
|
||||
return_code, stderr, stdout = run_shell_command(
|
||||
[
|
||||
"python3",
|
||||
"-m",
|
||||
"pytest",
|
||||
"--collect-only",
|
||||
"tests",
|
||||
f"--report-log={log_file}",
|
||||
]
|
||||
)
|
||||
if return_code != 0:
|
||||
print("Test module discovery failed.")
|
||||
print("STDOUT:", stdout)
|
||||
print("STDERR:", stderr)
|
||||
exit(return_code)
|
||||
print("---------- collected test modules ----------")
|
||||
test_files = parse_test_log(log_file)
|
||||
print("Found %d test modules." % (len(test_files)))
|
||||
print("--------------------------------------------")
|
||||
print("\n".join(test_files))
|
||||
return test_files
|
||||
|
||||
|
||||
def run_test(testmodule, gpu_tokens, continue_on_fail):
|
||||
global LAST_CODE
|
||||
with GPU_LOCK:
|
||||
if LAST_CODE != 0:
|
||||
return
|
||||
target_gpu = gpu_tokens.pop()
|
||||
env_vars = {
|
||||
"HIP_VISIBLE_DEVICES": str(target_gpu),
|
||||
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
|
||||
}
|
||||
testfile = extract_filename(testmodule)
|
||||
if continue_on_fail:
|
||||
cmd = ["python3", "-m", "pytest",
|
||||
"--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
|
||||
f"--csv={base_dir}/{testfile}_log.csv",
|
||||
"--csv-columns", "id,module,name,file,status,duration",
|
||||
f"--html={base_dir}/{testfile}_log.html",
|
||||
"--reruns", "3", "-v", testmodule]
|
||||
else:
|
||||
cmd = ["python3", "-m", "pytest",
|
||||
"--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
|
||||
f"--csv={base_dir}/{testfile}_log.csv",
|
||||
"--csv-columns", "id,module,name,file,status,duration",
|
||||
f"--html={base_dir}/{testfile}_log.html",
|
||||
"--reruns", "3", "-x", "-v", testmodule]
|
||||
global LAST_CODE
|
||||
with GPU_LOCK:
|
||||
if LAST_CODE != 0:
|
||||
return
|
||||
target_gpu = gpu_tokens.pop()
|
||||
env_vars = {
|
||||
"HIP_VISIBLE_DEVICES": str(target_gpu),
|
||||
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
|
||||
}
|
||||
testfile = extract_filename(testmodule)
|
||||
if continue_on_fail:
|
||||
cmd = [
|
||||
"python3",
|
||||
"-m",
|
||||
"pytest",
|
||||
"--json-report",
|
||||
f"--json-report-file={base_dir}/{testfile}_log.json",
|
||||
f"--csv={base_dir}/{testfile}_log.csv",
|
||||
"--csv-columns",
|
||||
"id,module,name,file,status,duration",
|
||||
f"--html={base_dir}/{testfile}_log.html",
|
||||
"--reruns",
|
||||
"3",
|
||||
"-v",
|
||||
testmodule,
|
||||
]
|
||||
else:
|
||||
cmd = [
|
||||
"python3",
|
||||
"-m",
|
||||
"pytest",
|
||||
"--json-report",
|
||||
f"--json-report-file={base_dir}/{testfile}_log.json",
|
||||
f"--csv={base_dir}/{testfile}_log.csv",
|
||||
"--csv-columns",
|
||||
"id,module,name,file,status,duration",
|
||||
f"--html={base_dir}/{testfile}_log.html",
|
||||
"--reruns",
|
||||
"3",
|
||||
"-x",
|
||||
"-v",
|
||||
testmodule,
|
||||
]
|
||||
|
||||
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
|
||||
with GPU_LOCK:
|
||||
gpu_tokens.append(target_gpu)
|
||||
if LAST_CODE == 0:
|
||||
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
|
||||
print(stdout)
|
||||
print(stderr)
|
||||
if continue_on_fail == False:
|
||||
LAST_CODE = return_code
|
||||
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
|
||||
with GPU_LOCK:
|
||||
gpu_tokens.append(target_gpu)
|
||||
if LAST_CODE == 0:
|
||||
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
|
||||
print(stdout)
|
||||
print(stderr)
|
||||
if continue_on_fail == False:
|
||||
LAST_CODE = return_code
|
||||
|
||||
|
||||
def run_parallel(all_testmodules, p, c):
|
||||
print(f"Running tests with parallelism = {p}")
|
||||
available_gpu_tokens = list(range(p))
|
||||
executor = ThreadPoolExecutor(max_workers=p)
|
||||
# walking through test modules.
|
||||
for testmodule in all_testmodules:
|
||||
executor.submit(run_test, testmodule, available_gpu_tokens, c)
|
||||
# waiting for all modules to finish.
|
||||
executor.shutdown(wait=True)
|
||||
print(f"Running tests with parallelism = {p}")
|
||||
available_gpu_tokens = list(range(p))
|
||||
executor = ThreadPoolExecutor(max_workers=p)
|
||||
# walking through test modules.
|
||||
for testmodule in all_testmodules:
|
||||
executor.submit(run_test, testmodule, available_gpu_tokens, c)
|
||||
# waiting for all modules to finish.
|
||||
executor.shutdown(wait=True)
|
||||
|
||||
|
||||
def find_num_gpus():
|
||||
cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
|
||||
_, _, stdout = run_shell_command(cmd, shell=True)
|
||||
return int(stdout)
|
||||
cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
|
||||
_, _, stdout = run_shell_command(cmd, shell=True)
|
||||
return int(stdout)
|
||||
|
||||
|
||||
def main(args):
|
||||
all_testmodules = collect_testmodules()
|
||||
run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
|
||||
generate_final_report()
|
||||
exit(LAST_CODE)
|
||||
all_testmodules = collect_testmodules()
|
||||
run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
|
||||
generate_final_report()
|
||||
exit(LAST_CODE)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['HSA_TOOLS_LIB'] = "libroctracer64.so"
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-p",
|
||||
"--parallel",
|
||||
type=int,
|
||||
help="number of tests to run in parallel")
|
||||
parser.add_argument("-c",
|
||||
"--continue_on_fail",
|
||||
action='store_true',
|
||||
help="continue on failure")
|
||||
args = parser.parse_args()
|
||||
if args.continue_on_fail:
|
||||
print("continue on fail is set")
|
||||
if args.parallel is None:
|
||||
sys_gpu_count = find_num_gpus()
|
||||
args.parallel = sys_gpu_count
|
||||
print("%d GPUs detected." % sys_gpu_count)
|
||||
if __name__ == "__main__":
|
||||
os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-p", "--parallel", type=int, help="number of tests to run in parallel"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--continue_on_fail", action="store_true", help="continue on failure"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.continue_on_fail:
|
||||
print("continue on fail is set")
|
||||
if args.parallel is None:
|
||||
sys_gpu_count = find_num_gpus()
|
||||
args.parallel = sys_gpu_count
|
||||
print("%d GPUs detected." % sys_gpu_count)
|
||||
|
||||
main(args)
|
||||
main(args)
|
||||
|
@ -116,6 +116,7 @@ def build_jaxlib_wheel(
|
||||
if compiler == "clang":
|
||||
clang_path = find_clang_path()
|
||||
if clang_path:
|
||||
LOG.info("Found clang at path: %s", clang_path)
|
||||
cmd.append("--clang_path=%s" % clang_path)
|
||||
else:
|
||||
raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")
|
||||
@ -315,6 +316,21 @@ def main():
|
||||
LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
|
||||
shutil.copy(whl, wheelhouse_dir)
|
||||
|
||||
# Delete the 'dist' directory since it causes permissions issues
|
||||
logging.info('Deleting dist, egg-info and cache directory')
|
||||
shutil.rmtree(os.path.join(args.jax_path, "dist"))
|
||||
shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info"))
|
||||
shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__"))
|
||||
|
||||
# Make the wheels deleteable by the runner
|
||||
whl_house = os.path.join(args.jax_path, "wheelhouse")
|
||||
logging.info("Changing permissions for %s" % whl_house)
|
||||
mode = 0o664
|
||||
for item in os.listdir(whl_house):
|
||||
whl_path = os.path.join(whl_house, item)
|
||||
if os.path.isfile(whl_path):
|
||||
os.chmod(whl_path, mode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
53
build/rocm/upload_wheels.sh
Normal file
53
build/rocm/upload_wheels.sh
Normal file
@ -0,0 +1,53 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Check for user-supplied arguments.
|
||||
if [[ $# -lt 2 ]]; then
|
||||
echo "Usage: $0 <jax_home_directory> <version>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set JAX_HOME and RELEASE_VERSION from user arguments.
|
||||
JAX_HOME=$1
|
||||
RELEASE_VERSION=$2
|
||||
WHEELHOUSE="$JAX_HOME/wheelhouse"
|
||||
|
||||
# Projects to upload separately to PyPI.
|
||||
PROJECTS=("jax_rocm60_pjrt" "jax_rocm60_plugin")
|
||||
|
||||
# PyPI API Token.
|
||||
PYPI_API_TOKEN=${PYPI_API_TOKEN:-"pypi-replace_with_token"}
|
||||
|
||||
# Ensure the specified JAX_HOME and wheelhouse directories exists.
|
||||
if [[ ! -d "$JAX_HOME" ]]; then
|
||||
echo "Error: The specified JAX_HOME directory does not exist: $JAX_HOME"
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -d "$WHEELHOUSE" ]]; then
|
||||
echo "Error: The wheelhouse directory does not exist: $WHEELHOUSE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
upload_and_release_project() {
|
||||
local project=$1
|
||||
|
||||
echo "Searching for wheels matching project: $project version: $RELEASE_VERSION..."
|
||||
wheels=($(ls $WHEELHOUSE | grep "^${project}-${RELEASE_VERSION}[.-].*\.whl"))
|
||||
if [[ ${#wheels[@]} -eq 0 ]]; then
|
||||
echo "No wheels found for project: $project version: $RELEASE_VERSION. Skipping..."
|
||||
return
|
||||
fi
|
||||
echo "Found wheels for $project: ${wheels[*]}"
|
||||
|
||||
echo "Uploading wheels for $project version $RELEASE_VERSION to PyPI..."
|
||||
for wheel in "${wheels[@]}"; do
|
||||
twine upload --verbose --repository pypi --non-interactive --username "__token__" --password "$PYPI_API_TOKEN" "$WHEELHOUSE/$wheel"
|
||||
done
|
||||
}
|
||||
|
||||
# Install twine if not already installed.
|
||||
python -m pip install --upgrade twine
|
||||
|
||||
# Iterate over each project and upload its wheels.
|
||||
for project in "${PROJECTS[@]}"; do
|
||||
upload_and_release_project $project
|
||||
done
|
65
rocm-downstream-dev-guide.md
Normal file
65
rocm-downstream-dev-guide.md
Normal file
@ -0,0 +1,65 @@
|
||||
# ROCm CI Dev Guide
|
||||
|
||||
This guide lays out how to do some dev operations, what branches live in this repo, and what CI workflows live in this repo.
|
||||
|
||||
# Quick Tips
|
||||
|
||||
1. Always use "Squash and Merge" when merging PRs into `rocm-main` (unless you're merging the daily sync from upstream).
|
||||
2. When submitting a PR to `rocm-main`, make sure that your feature branch started from `rocm-main`. When you started working on your feature, did you do `git checkout rocm-main && git checkout -b <my feature branch>`?
|
||||
3. Always fill out your PR's description with an explanation of why we need this change to give context to your fellow devs (and your future self).
|
||||
4. In the PR description, link to the story or GitHub issue that this PR solves.
|
||||
|
||||
# Processes
|
||||
|
||||
## Making a Change
|
||||
|
||||
1. Clone `rocm/jax` and check out the `rocm-main` branch.
|
||||
2. Create a new feature branch with `git checkout -b <my feature name>`.
|
||||
3. Make your changes on the feature branch and test them locally.
|
||||
4. Push your changes to a new feature branch in `rocm/jax` by running
|
||||
`git push orgin HEAD`.
|
||||
5. Open a PR from your new feature branch into `rocm-main` with a nice description telling your
|
||||
team members what the change is for. Bonus points if you can link it to an issue or story.
|
||||
6. Add reviewers, wait for approval, and make sure CI passes.
|
||||
7. Depending on if your specific change, either:
|
||||
a. If this is a normal, run-of-the-mill change that we want to put upstream, add the
|
||||
`open-upstream` label to your PR and close your PR. In a few minutes, Actions will
|
||||
comment on your PR with a link that lets you open a new PR into upstream. The link will
|
||||
autofill some PR info, and the new PR be created on a new branch that has the same name
|
||||
as your old feature branch, but with the `-upstream` suffix appended to the end of it.
|
||||
If upstream reviewers request some changes to the new PR before merging, you can add
|
||||
or modify commits on the new `-upstream` feature branch.
|
||||
b. If this is an urgent change that we want in `rocm-main` right now but also want upstream,
|
||||
add the `open-upstream` label, merge your PR, and then follow the link that
|
||||
c. If this is a change that we only want to keep in `rocm/jax` and not push into upstream,
|
||||
squash and merge your PR.
|
||||
|
||||
If you submitted your PR upstream with `open-upstream`, you should see your change in `rocm-main`
|
||||
the next time the `ROCm Nightly Upstream Sync` workflow is run and the PR that it creates is
|
||||
merged.
|
||||
|
||||
When using the `open-upstream` label to move changes to upstream, it's best to put the label on the PR when you either close or merge the PR. The GitHub Actions workflow that handles the `open-upstream` label uses `git rebase --onto` to set up the changes destined for upstream. Adding the label and creating this branch long after the PR has been merged or closed can cause merge conflicts with new upstream code and cause the workflow to fail. Adding the label right after creating your PR means that 1) any changes you make to your downstream PR while it is in review won't make it to upstream, and it is up to you to cherry-pick those changes into the upstream branch or remove and re-add the `open-upstream` label to get the Actions workflow to do it for you, and 2) that you're proposing changes to upstream that the rest of the AMD team might still have comments on.
|
||||
|
||||
## Daily Upstream Sync
|
||||
|
||||
Every day, GitHub Actions will attempt to run the `ROCm Nightly Upstream Sync` workflow. This job
|
||||
normally does this on its own, but requires a developer to intervene if there's a merge conflict
|
||||
or if the PR fails CI. Devs should fix or resolve issues with the merge by adding commits to the
|
||||
PR's branch.
|
||||
|
||||
# Branches
|
||||
|
||||
* `rocm-main` - the default "trunk" branch for this repo. Should only be changed submitting PRs to it from feature branches created by devs.
|
||||
* `main` - a copy of `jax-ml/jax:main`. This branch is "read-only" and should only be changed by GitHub Actions.
|
||||
|
||||
# CI Workflows
|
||||
|
||||
We use GitHub Actions to run tests on PRs and to automate some of our
|
||||
development tasks. These all live in `.github/workflows`.
|
||||
|
||||
| Name | File | Trigger | Description |
|
||||
|----------------------------|----------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------|
|
||||
| ROCm GPU CI | `rocm-ci.yml` | Open or commit changes to a PR targeting `rocm-main` | Builds and runs JAX on ROCm for PRs going into `rocm-main` |
|
||||
| ROCm Open Upstream PR | `rocm-open-upstream-pr.yml` | Add the `open-upstream` label to a PR | Copies changes from a PR aimed at `rocm-main` into a new PR aimed at upstream's `main` |
|
||||
| ROCm Nightly Upstream Sync | `rocm-nightly-upstream-sync.yml` | Runs nightly, can be triggered manually via Actions | Opens a PR that merges changes from upstream `main` into our `rocm-main` branch |
|
||||
|
@ -450,6 +450,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
|
||||
|
||||
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
|
||||
@jtu.sample_product(
|
||||
[dict(l_max=l_max, num_z=num_z)
|
||||
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
|
||||
|
Loading…
x
Reference in New Issue
Block a user